Skip to content

Commit 06505ba

Browse files
committed
Less eval steps during training
1 parent 1345700 commit 06505ba

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

examples/train_unconditional.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ def transforms(examples):
147147

148148
accelerator.wait_for_everyone()
149149

150-
# Generate a sample image for visual inspection
150+
# Generate sample images for visual inspection
151151
if accelerator.is_main_process:
152-
with torch.no_grad():
152+
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
153153
pipeline = DDPMPipeline(
154154
unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
155155
scheduler=noise_scheduler,
@@ -159,9 +159,11 @@ def transforms(examples):
159159
# run pipeline in inference (sample random noise and denoise)
160160
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
161161

162-
# denormalize the images and save to tensorboard
163-
images_processed = (images * 255).round().astype("uint8")
164-
accelerator.trackers[0].writer.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
162+
# denormalize the images and save to tensorboard
163+
images_processed = (images * 255).round().astype("uint8")
164+
accelerator.trackers[0].writer.add_images(
165+
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
166+
)
165167

166168
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
167169
# save the model
@@ -184,7 +186,8 @@ def transforms(examples):
184186
parser.add_argument("--train_batch_size", type=int, default=16)
185187
parser.add_argument("--eval_batch_size", type=int, default=16)
186188
parser.add_argument("--num_epochs", type=int, default=100)
187-
parser.add_argument("--save_model_epochs", type=int, default=5)
189+
parser.add_argument("--save_images_epochs", type=int, default=10)
190+
parser.add_argument("--save_model_epochs", type=int, default=10)
188191
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
189192
parser.add_argument("--learning_rate", type=float, default=1e-4)
190193
parser.add_argument("--lr_scheduler", type=str, default="cosine")

0 commit comments

Comments
 (0)