 
 from accelerate import Accelerator
 
+from denoising_diffusion_pytorch.attend import Attend
 from denoising_diffusion_pytorch.fid_evaluation import FIDEvaluation
 
 from denoising_diffusion_pytorch.version import __version__
@@ -43,6 +44,11 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
+def cast_tuple(t, length = 1):
+    if isinstance(t, tuple):
+        return t
+    return ((t,) * length)
+
 def identity(t, *args, **kwargs):
     return t
 
@@ -77,14 +83,6 @@ def unnormalize_to_zero_to_one(t):
 
 # small helper modules
 
-class Residual(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-
-    def forward(self, x, *args, **kwargs):
-        return self.fn(x, *args, **kwargs) + x
-
 def Upsample(dim, dim_out = None):
     return nn.Sequential(
         nn.Upsample(scale_factor = 2, mode = 'nearest'),
@@ -105,16 +103,6 @@ def __init__(self, dim):
     def forward(self, x):
         return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.fn = fn
-        self.norm = RMSNorm(dim)
-
-    def forward(self, x):
-        x = self.norm(x)
-        return self.fn(x)
-
 # sinusoidal positional embeds
 
 class SinusoidalPosEmb(nn.Module):
@@ -195,11 +183,18 @@ def forward(self, x, time_emb = None):
         return h + self.res_conv(x)
 
 class LinearAttention(nn.Module):
-    def __init__(self, dim, heads = 4, dim_head = 32):
+    def __init__(
+        self,
+        dim,
+        heads = 4,
+        dim_head = 32
+    ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
         hidden_dim = dim_head * heads
+
+        self.norm = RMSNorm(dim)
         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
 
         self.to_out = nn.Sequential(
@@ -209,6 +204,9 @@ def __init__(self, dim, heads = 4, dim_head = 32):
 
     def forward(self, x):
         b, c, h, w = x.shape
+
+        x = self.norm(x)
+
         qkv = self.to_qkv(x).chunk(3, dim = 1)
         q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
 
@@ -224,25 +222,32 @@ def forward(self, x):
         return self.to_out(out)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 4, dim_head = 32):
+    def __init__(
+        self,
+        dim,
+        heads = 4,
+        dim_head = 32,
+        flash = False
+    ):
         super().__init__()
-        self.scale = dim_head ** -0.5
         self.heads = heads
         hidden_dim = dim_head * heads
 
+        self.norm = RMSNorm(dim)
+        self.attend = Attend(flash = flash)
+
         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
 
     def forward(self, x):
         b, c, h, w = x.shape
-        qkv = self.to_qkv(x).chunk(3, dim = 1)
-        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
 
-        q = q * self.scale
+        x = self.norm(x)
 
-        sim = einsum('b h d i, b h d j -> b h i j', q, k)
-        attn = sim.softmax(dim = -1)
-        out = einsum('b h i j, b h d j -> b h i d', attn, v)
+        qkv = self.to_qkv(x).chunk(3, dim = 1)
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv)
+
+        out = self.attend(q, k, v)
 
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
         return self.to_out(out)
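# Rough sketch (an assumption, not the actual denoising_diffusion_pytorch.attend code):
# the imported Attend module is expected to perform scaled dot-product attention over
# (batch, heads, seq, dim) tensors, using torch's fused kernel when flash = True and a
# plain einsum softmax path otherwise. The name naive_attend is hypothetical.
import torch
import torch.nn.functional as F

def naive_attend(q, k, v, flash = False):
    if flash:
        # fused scaled dot-product attention, available in torch >= 2.0
        return F.scaled_dot_product_attention(q, k, v)

    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)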
@@ -255,14 +260,16 @@ def __init__(
         dim,
         init_dim = None,
         out_dim = None,
-        dim_mults=(1, 2, 4, 8),
+        dim_mults = (1, 2, 4, 8),
         channels = 3,
         self_condition = False,
         resnet_block_groups = 8,
         learned_variance = False,
         learned_sinusoidal_cond = False,
         random_fourier_features = False,
-        learned_sinusoidal_dim = 16
+        learned_sinusoidal_dim = 16,
+        full_attn = (False, False, False, True),
+        flash_attn = False
     ):
         super().__init__()
 
@@ -300,34 +307,45 @@ def __init__(
             nn.Linear(time_dim, time_dim)
         )
 
+        # attention
+
+        full_attn = cast_tuple(full_attn, length = len(dim_mults))
+        assert len(full_attn) == len(dim_mults)
+
+        FullAttention = partial(Attention, flash = flash_attn)
+
         # layers
 
         self.downs = nn.ModuleList([])
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        for ind, (dim_in, dim_out) in enumerate(in_out):
+        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(in_out, full_attn)):
             is_last = ind >= (num_resolutions - 1)
 
+            attn_klass = FullAttention if layer_full_attn else LinearAttention
+
             self.downs.append(nn.ModuleList([
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
-                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
+                attn_klass(dim_in),
                 Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
             ]))
 
         mid_dim = dims[-1]
         self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
-        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
+        self.mid_attn = FullAttention(mid_dim)
         self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
 
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
+        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(reversed(in_out), reversed(full_attn))):
             is_last = ind == (len(in_out) - 1)
 
+            attn_klass = FullAttention if layer_full_attn else LinearAttention
+
             self.ups.append(nn.ModuleList([
                 block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                 block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
-                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                attn_klass(dim_out),
                 Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
             ]))
 
@@ -354,13 +372,13 @@ def forward(self, x, time, x_self_cond = None):
             h.append(x)
 
             x = block2(x, t)
-            x = attn(x)
+            x = attn(x) + x
             h.append(x)
 
             x = downsample(x)
 
         x = self.mid_block1(x, t)
-        x = self.mid_attn(x)
+        x = self.mid_attn(x) + x
         x = self.mid_block2(x, t)
 
         for block1, block2, attn, upsample in self.ups:
@@ -369,7 +387,7 @@ def forward(self, x, time, x_self_cond = None):
 
             x = torch.cat((x, h.pop()), dim = 1)
             x = block2(x, t)
-            x = attn(x)
+            x = attn(x) + x
 
             x = upsample(x)
 
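# Usage sketch for the new Unet arguments introduced above (dim and dim_mults are
# example values; full_attn and flash_attn are the parameters added in this diff).
# full_attn selects full attention per resolution stage, with linear attention used
# wherever it is False; flash_attn routes full attention through Attend(flash = True).
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    full_attn = (False, False, False, True),  # full attention only at the deepest stage
    flash_attn = True
)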