 from functools import partial
 from collections import namedtuple
 
+from beartype import beartype
+
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
 from torch.fft import fft2, ifft2
 
-from einops import rearrange, reduce, pack
+from einops import rearrange, reduce
 from einops.layers.torch import Rearrange
 
 from tqdm.auto import tqdm
@@ -224,15 +226,17 @@ def forward(self, x, c):
 
 # model
 
+@beartype
 class Unet(nn.Module):
     def __init__(
         self,
         dim,
         image_size,
         init_dim = None,
         out_dim = None,
-        dim_mults = (1, 2, 4, 8),
+        dim_mults: tuple = (1, 2, 4, 8),
         channels = 3,
+        full_self_attn: tuple = (False, False, False, True),
         self_condition = False,
         resnet_block_groups = 8,
         conditioning_klass = Conditioning,
@@ -272,6 +276,7 @@ def __init__(
         # layers
 
         num_resolutions = len(in_out)
+        assert len(full_self_attn) == num_resolutions
 
         self.conditioners = nn.ModuleList([])
 
@@ -283,15 +288,17 @@ def __init__(
 
         curr_fmap_size = image_size
 
-        for ind, (dim_in, dim_out) in enumerate(in_out):
+        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(in_out, full_self_attn)):
             is_last = ind >= (num_resolutions - 1)
+            attn_klass = Attention if full_attn else LinearAttention
 
             self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in))
 
+
             self.downs.append(nn.ModuleList([
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
-                Residual(LinearAttention(dim_in)),
+                Residual(attn_klass(dim_in)),
                 Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
             ]))
 
@@ -314,15 +321,16 @@ def __init__(
 
         self.ups = nn.ModuleList([])
 
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
+        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(reversed(in_out), reversed(full_self_attn))):
             is_last = ind == (len(in_out) - 1)
+            attn_klass = Attention if full_attn else LinearAttention
 
             skip_connect_dim = dim_in * (2 if self.skip_connect_condition_fmaps else 1)
 
             self.ups.append(nn.ModuleList([
                 block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                 block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
-                Residual(LinearAttention(dim_out)),
+                Residual(attn_klass(dim_out)),
                 Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
             ]))
 
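
Taken together, these changes let the caller choose, per resolution stage, between the existing LinearAttention and full quadratic self-attention, with the new tuple arguments type-checked at call time via beartype. Below is a minimal usage sketch based only on the constructor arguments visible in this diff; the dim, image_size and flag values are illustrative, not taken from the commit.

    # illustrative values only; every argument name comes from the Unet signature in the diff above
    model = Unet(
        dim = 64,
        image_size = 128,
        dim_mults = (1, 2, 4, 8),
        full_self_attn = (False, False, False, True),  # full attention only at the coarsest stage
        channels = 3
    )

    # the new assert requires one flag per resolution, so this should raise:
    # Unet(dim = 64, image_size = 128, dim_mults = (1, 2, 4, 8), full_self_attn = (False, True))

    # and since the class is decorated with @beartype and full_self_attn is annotated as tuple,
    # passing a list such as [False, False, False, True] should be rejected with a type-check error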
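As a design note, the default of (False, False, False, True) reserves full self-attention for the coarsest feature map, where the quadratic cost over spatial positions is affordable, and keeps the cheaper LinearAttention at the higher-resolution stages; the flag tuple is reversed alongside in_out for the upsampling path, so both halves of the U-Net use matching attention types at each resolution.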