 from tqdm import tqdm
 from einops import rearrange
 
+try:
+    from apex import amp
+    APEX_AVAILABLE = True
+except:
+    APEX_AVAILABLE = False
+
 # constants
 
 SAVE_AND_SAMPLE_EVERY = 1000
@@ -37,6 +43,13 @@ def cycle(dl):
     for data in dl:
         yield data
 
+def loss_backwards(fp16, loss, optimizer, **kwargs):
+    if fp16:
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
+            scaled_loss.backward(**kwargs)
+    else:
+        loss.backward(**kwargs)
+
 # small helper modules
 
 class EMA():
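
Note on the new loss_backwards helper above: it routes the backward pass through apex's loss scaler only when mixed precision is enabled, and falls back to a plain loss.backward() otherwise. A minimal sketch of how a training step would call it (the model, batch and optimizer names here are placeholders, not part of this commit):

    loss = model(batch)                       # hypothetical forward pass
    loss_backwards(fp16, loss, optimizer)     # scales the loss via apex only when fp16 is True
    optimizer.step()
    optimizer.zero_grad()
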
@@ -107,7 +120,7 @@ def forward(self, x):
 # building block modules
 
 class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups = 32):
+    def __init__(self, dim, dim_out, groups = 8):
         super().__init__()
         self.block = nn.Sequential(
             nn.Conv2d(dim, dim_out, 3, padding = 1),
@@ -118,7 +131,7 @@ def forward(self, x):
         return self.block(x)
 
 class ResnetBlock(nn.Module):
-    def __init__(self, dim, dim_out, *, time_emb_dim, groups = 32):
+    def __init__(self, dim, dim_out, *, time_emb_dim, groups = 8):
         super().__init__()
         self.mlp = nn.Sequential(
             Mish(),
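
The groups default that drops from 32 to 8 in both blocks presumably feeds the group-normalization layer inside Block (that line sits outside the visible hunk). With nn.GroupNorm the group count must divide the channel count, so 8 fits narrow feature maps that 32 would reject; a rough sketch of the assumed Conv -> GroupNorm -> activation pattern:

    # assumed layout of Block.block; GroupNorm requires dim_out % groups == 0
    block = nn.Sequential(
        nn.Conv2d(dim, dim_out, 3, padding = 1),
        nn.GroupNorm(groups, dim_out),
        Mish()
    )
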
@@ -157,7 +170,7 @@ def forward(self, x):
 # model
 
 class Unet(nn.Module):
-    def __init__(self, dim, out_dim = None, dim_mults = (1, 2, 4, 8), groups = 32):
+    def __init__(self, dim, out_dim = None, dim_mults = (1, 2, 4, 8), groups = 8):
         super().__init__()
         dims = [3, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
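
For concreteness, with an example dim = 64 and the default dim_mults = (1, 2, 4, 8), the two lines above expand to:

    dims   = [3, 64, 128, 256, 512]
    in_out = [(3, 64), (64, 128), (128, 256), (256, 512)]   # one (dim_in, dim_out) pair per resolution level
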
@@ -178,6 +191,7 @@ def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 32):
 
             self.downs.append(nn.ModuleList([
                 ResnetBlock(dim_in, dim_out, time_emb_dim = dim),
+                ResnetBlock(dim_out, dim_out, time_emb_dim = dim),
                 Residual(Rezero(LinearAttention(dim_out))),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -192,6 +206,7 @@ def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 32):
 
             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out * 2, dim_in, time_emb_dim = dim),
+                ResnetBlock(dim_in, dim_in, time_emb_dim = dim),
                 Residual(Rezero(LinearAttention(dim_in))),
                 Upsample(dim_in) if not is_last else nn.Identity()
             ]))
@@ -208,8 +223,9 @@ def forward(self, x, time):
 
         h = []
 
-        for resnet, attn, downsample in self.downs:
+        for resnet, resnet2, attn, downsample in self.downs:
             x = resnet(x, t)
+            x = resnet2(x, t)
             x = attn(x)
             h.append(x)
             x = downsample(x)
@@ -218,9 +234,10 @@ def forward(self, x, time):
         x = self.mid_attn(x)
         x = self.mid_block2(x, t)
 
-        for resnet, attn, upsample in self.ups:
+        for resnet, resnet2, attn, upsample in self.ups:
             x = torch.cat((x, h.pop()), dim = 1)
             x = resnet(x, t)
+            x = resnet2(x, t)
             x = attn(x)
             x = upsample(x)
 
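
Taken together with the constructor hunks above, every resolution level now runs two ResnetBlocks before attention, so each entry of self.downs / self.ups carries a fourth module and the loop unpacking changes to match. On the up path the first block still takes dim_out * 2 channels because of the torch.cat with the skip connection; the newly added block then works at dim_in. A stripped-down sketch of one up-path level as it runs after this commit:

    x = torch.cat((x, h.pop()), dim = 1)   # concat skip connection: dim_out + dim_out channels
    x = resnet(x, t)                       # dim_out * 2 -> dim_in
    x = resnet2(x, t)                      # dim_in -> dim_in (the added block)
    x = attn(x)
    x = upsample(x)
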
@@ -417,23 +434,40 @@ def __init__(
         train_lr = 2e-5,
         train_num_steps = 100000,
         gradient_accumulate_every = 2,
+        fp16 = False
     ):
         super().__init__()
         self.model = diffusion_model
+        self.ema = EMA(ema_decay)
+        self.ema_model = copy.deepcopy(self.model)
 
         self.image_size = image_size
         self.gradient_accumulate_every = gradient_accumulate_every
         self.train_num_steps = train_num_steps
 
-        self.ema = EMA(ema_decay)
-        self.ema_model = copy.deepcopy(self.model)
-
         self.ds = Dataset(folder, image_size)
         self.dl = cycle(data.DataLoader(self.ds, batch_size = train_batch_size, shuffle=True, pin_memory=True))
         self.opt = Adam(diffusion_model.parameters(), lr = train_lr)
 
         self.step = 0
 
+        assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex must be installed in order for mixed precision training to be turned on'
+
+        self.fp16 = fp16
+        if fp16:
+            (self.model, self.ema_model), self.opt = amp.initialize([self.model, self.ema_model], self.opt, opt_level = 'O1')
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.ema_model.load_state_dict(self.model.state_dict())
+
+    def step_ema(self):
+        if self.step < 2000:
+            self.reset_parameters()
+            return
+        self.ema.update_model_average(self.ema_model, self.model)
+
     def save(self, milestone):
         data = {
             'step': self.step,
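
Two details in the hunk above are easy to miss: amp.initialize wraps the EMA copy alongside the online model so both live in the same (possibly half-precision) representation, and step_ema keeps the EMA weights pinned to the raw weights for the first 2000 steps before exponential averaging kicks in. The averaging itself is delegated to the EMA helper defined earlier in the file; a sketch of the update it is assumed to perform (not shown in this diff):

    # assumed behaviour of EMA.update_model_average:
    # ema_param <- decay * ema_param + (1 - decay) * online_param
    for ema_p, online_p in zip(ema_model.parameters(), model.parameters()):
        ema_p.data.mul_(ema_decay).add_(online_p.data, alpha = 1 - ema_decay)
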
@@ -450,18 +484,20 @@ def load(self, milestone):
         self.ema_model.load_state_dict(data['ema'])
 
     def train(self):
+        backwards = partial(loss_backwards, self.fp16)
+
         while self.step < self.train_num_steps:
             for i in range(self.gradient_accumulate_every):
                 data = next(self.dl).cuda()
                 loss = self.model(data)
                 print(f'{self.step}: {loss.item()}')
-                (loss / self.gradient_accumulate_every).backward()
+                backwards(loss / self.gradient_accumulate_every, self.opt)
 
             self.opt.step()
             self.opt.zero_grad()
 
             if self.step % UPDATE_EMA_EVERY == 0:
-                self.ema.update_model_average(self.ema_model, self.model)
+                self.step_ema()
 
             if self.step % SAVE_AND_SAMPLE_EVERY == 0:
                 milestone = self.step // SAVE_AND_SAMPLE_EVERY
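
End-to-end, mixed precision is opt-in from the Trainer constructor. A hedged usage sketch (the folder path, dim, image size and batch size are placeholder values, and the GaussianDiffusion constructor arguments are not shown in this diff):

    model = Unet(dim = 64)
    diffusion = GaussianDiffusion(model)    # diffusion wrapper defined earlier in this file; its args are omitted here
    trainer = Trainer(diffusion, folder = './path/to/images', image_size = 128, train_batch_size = 32, fp16 = True)
    trainer.train()                         # fp16 = True requires apex to be installed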