@@ -147,9 +147,9 @@ def transforms(examples):
147147
148148 accelerator .wait_for_everyone ()
149149
150- # Generate a sample image for visual inspection
150+ # Generate sample images for visual inspection
151151 if accelerator .is_main_process :
152- with torch . no_grad () :
152+ if epoch % args . save_images_epochs == 0 or epoch == args . num_epochs - 1 :
153153 pipeline = DDPMPipeline (
154154 unet = accelerator .unwrap_model (ema_model .averaged_model if args .use_ema else model ),
155155 scheduler = noise_scheduler ,
@@ -159,9 +159,11 @@ def transforms(examples):
159159 # run pipeline in inference (sample random noise and denoise)
160160 images = pipeline (generator = generator , batch_size = args .eval_batch_size , output_type = "numpy" )["sample" ]
161161
162- # denormalize the images and save to tensorboard
163- images_processed = (images * 255 ).round ().astype ("uint8" )
164- accelerator .trackers [0 ].writer .add_images ("test_samples" , images_processed .transpose (0 , 3 , 1 , 2 ), epoch )
162+ # denormalize the images and save to tensorboard
163+ images_processed = (images * 255 ).round ().astype ("uint8" )
164+ accelerator .trackers [0 ].writer .add_images (
165+ "test_samples" , images_processed .transpose (0 , 3 , 1 , 2 ), epoch
166+ )
165167
166168 if epoch % args .save_model_epochs == 0 or epoch == args .num_epochs - 1 :
167169 # save the model
@@ -184,7 +186,8 @@ def transforms(examples):
184186 parser .add_argument ("--train_batch_size" , type = int , default = 16 )
185187 parser .add_argument ("--eval_batch_size" , type = int , default = 16 )
186188 parser .add_argument ("--num_epochs" , type = int , default = 100 )
187- parser .add_argument ("--save_model_epochs" , type = int , default = 5 )
189+ parser .add_argument ("--save_images_epochs" , type = int , default = 10 )
190+ parser .add_argument ("--save_model_epochs" , type = int , default = 10 )
188191 parser .add_argument ("--gradient_accumulation_steps" , type = int , default = 1 )
189192 parser .add_argument ("--learning_rate" , type = float , default = 1e-4 )
190193 parser .add_argument ("--lr_scheduler" , type = str , default = "cosine" )
0 commit comments