@@ -198,33 +198,49 @@ def forward(self, x):
198198
199199# conditioning class
200200
201- class FourierConditioning (nn .Module ):
202- def __init__ (self , dim ):
201+ class Conditioning (nn .Module ):
202+ def __init__ (self , fmap_size , dim ):
203203 super ().__init__ ()
204+ self .ff_parser_attn_map = nn .Parameter (torch .ones (dim , fmap_size , fmap_size ))
205+
204206 self .norm_input = LayerNorm (dim , bias = True )
205207 self .norm_condition = LayerNorm (dim , bias = True )
206208
207209 def forward (self , x , c ):
210+
211+ # ff-parser in the paper, for modulating out the high frequencies
212+
213+ dtype = x .dtype
214+ x = fft2 (x )
215+ x = x * self .ff_parser_attn_map
216+ x = ifft2 (x ).real
217+ x = x .type (dtype )
218+
219+ # eq 3 in paper
220+
208221 normed_x = self .norm_input (x )
209222 normed_c = self .norm_condition (c )
210- return (normed_x * normed_c ) * c # eq 3 in paper
223+ return (normed_x * normed_c ) * c
211224
212225# model
213226
214227class Unet (nn .Module ):
215228 def __init__ (
216229 self ,
217230 dim ,
231+ image_size ,
218232 init_dim = None ,
219233 out_dim = None ,
220234 dim_mults = (1 , 2 , 4 , 8 ),
221235 channels = 3 ,
222236 self_condition = False ,
223237 resnet_block_groups = 8 ,
224- conditioning_klass = FourierConditioning
238+ conditioning_klass = Conditioning
225239 ):
226240 super ().__init__ ()
227241
242+ self .image_size = image_size
243+
228244 # determine dimensions
229245
230246 self .channels = channels
@@ -262,10 +278,12 @@ def __init__(
262278
263279 self .downs = nn .ModuleList ([])
264280
281+ curr_fmap_size = image_size
282+
265283 for ind , (dim_in , dim_out ) in enumerate (in_out ):
266284 is_last = ind >= (num_resolutions - 1 )
267285
268- self .conditioners .append (conditioning_klass (dim_in ))
286+ self .conditioners .append (conditioning_klass (curr_fmap_size , dim_in ))
269287
270288 self .downs .append (nn .ModuleList ([
271289 block_klass (dim_in , dim_in , time_emb_dim = time_dim ),
@@ -274,6 +292,9 @@ def __init__(
274292 Downsample (dim_in , dim_out ) if not is_last else nn .Conv2d (dim_in , dim_out , 3 , padding = 1 )
275293 ]))
276294
295+ if not is_last :
296+ curr_fmap_size //= 2
297+
277298 # middle blocks
278299
279300 mid_dim = dims [- 1 ]
@@ -402,7 +423,6 @@ def __init__(
402423 self ,
403424 model ,
404425 * ,
405- image_size ,
406426 timesteps = 1000 ,
407427 sampling_timesteps = None ,
408428 objective = 'pred_noise' ,
@@ -415,7 +435,7 @@ def __init__(
415435 self .channels = self .model .channels
416436 self .self_condition = self .model .self_condition
417437
418- self .image_size = image_size
438+ self .image_size = model . image_size
419439
420440 self .objective = objective
421441
0 commit comments