Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/unconditional_image_generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ A full training run takes 2 hours on 4xV100 GPUs.

<img src="https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png" width="700" />

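To obtain a quantitative evaluation of the generated images, pass `--compute_fid` to compute the Fréchet Inception Distance (FID) at each checkpointing step. Use `--num_samples_to_evaluate` to set the total number of images generated for the evaluation.
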
### Training with multiple GPUs

`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
Expand Down
2 changes: 2 additions & 0 deletions examples/unconditional_image_generation/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
accelerate>=0.16.0
torchvision
datasets
torchmetrics
torch-fidelity
68 changes: 66 additions & 2 deletions examples/unconditional_image_generation/train_unconditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import datasets
import torch
import torch.nn.functional as F
from torchmetrics.image.fid import FrechetInceptionDistance
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
Expand Down Expand Up @@ -129,6 +130,15 @@ def parse_args():
parser.add_argument(
"--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
)
parser.add_argument(
"--num_samples_to_evaluate", type=int, default=10000, help="Number of samples to generate for quantity assessment (e.g. FID)"
)
parser.add_argument(
"--compute_fid", action="store_true",
help=(
"If given, then it computes FID at each checkpointing step. Using `--num_samples_to_evaluate` determines the total number of samples to generate from diffusion model."
)
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
Expand Down Expand Up @@ -536,6 +546,25 @@ def transform_images(examples):
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

if args.compute_fid:
# The FID metric is created and updated only on the main process; the other processes just gather their images to it.
if accelerator.is_main_process:
fid = FrechetInceptionDistance(normalize=True,
reset_real_features=False,
sync_on_compute=False).to(device=accelerator.device)

# update FID for real images
fid_prog_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_main_process)
fid_prog_bar.set_description(f"Update FID - real images")
for step, batch in enumerate(train_dataloader):
imgs = batch["input"].to(weight_dtype)
gathered_imgs = accelerator.gather(imgs)
if accelerator.is_main_process:
fid.update(gathered_imgs, real=True)

del gathered_imgs
fid_prog_bar.update(1)

# Train!
for epoch in range(first_epoch, args.num_epochs):
model.train()
Expand Down Expand Up @@ -593,8 +622,8 @@ def transform_images(examples):
progress_bar.update(1)
global_step += 1

if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
Comment on lines +625 to +626
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe these two `if` conditions could be combined into a single check?

# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
Expand All @@ -618,6 +647,41 @@ def transform_images(examples):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")

# Compute FID
if args.compute_fid:
unet = accelerator.unwrap_model(model)
if args.use_ema:
ema_model.store(unet.parameters())
ema_model.copy_to(unet.parameters())

generator = torch.Generator(device=accelerator.device).manual_seed(accelerator.process_index)
pipeline = DDPMPipeline(
unet=unet,
scheduler=noise_scheduler,
)
num_batches = int(args.num_samples_to_evaluate / (args.eval_batch_size * accelerator.num_processes))
for step in range(num_batches):
generated = pipeline(
generator=generator,
batch_size=args.eval_batch_size,
num_inference_steps=args.ddpm_num_inference_steps,
output_type="np",
).images
generated = torch.tensor(generated, device=accelerator.device, dtype=weight_dtype)
generated = generated.permute(0, 3, 1, 2)
g_generated = accelerator.gather(generated)
if accelerator.is_main_process:
fid.update(g_generated, real=False)
del g_generated

if accelerator.is_main_process:
fid_value = float(fid.compute())
fid.reset()
logs.update({"FID": fid_value})

if args.use_ema:
ema_model.restore(unet.parameters())

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
if args.use_ema:
Expand Down