diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md
index 2990b3abf3f5..12c4e99aa0a3 100644
--- a/examples/unconditional_image_generation/README.md
+++ b/examples/unconditional_image_generation/README.md
@@ -76,6 +76,7 @@
 A full training run takes 2 hours on 4xV100 GPUs.
 
+To obtain a quantitative evaluation of generated images, you can compute FID by passing `--compute_fid`.
 
 ### Training with multiple GPUs
 
 `accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
diff --git a/examples/unconditional_image_generation/requirements.txt b/examples/unconditional_image_generation/requirements.txt
index f366720afd11..a00afb37766c 100644
--- a/examples/unconditional_image_generation/requirements.txt
+++ b/examples/unconditional_image_generation/requirements.txt
@@ -1,3 +1,5 @@
 accelerate>=0.16.0
 torchvision
 datasets
+torchmetrics
+torch-fidelity
\ No newline at end of file
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 1f5e1de240cb..aeb07cdd696f 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -11,6 +11,7 @@
 import datasets
 import torch
 import torch.nn.functional as F
+from torchmetrics.image.fid import FrechetInceptionDistance
 from accelerate import Accelerator, InitProcessGroupKwargs
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration
@@ -129,6 +130,15 @@ def parse_args():
     parser.add_argument(
         "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
     )
+    parser.add_argument(
+        "--num_samples_to_evaluate", type=int, default=10000, help="Number of samples to generate for quantitative assessment (e.g. FID)."
+    )
+    parser.add_argument(
+        "--compute_fid", action="store_true",
+        help=(
+            "If passed, compute FID at each checkpointing step. `--num_samples_to_evaluate` determines the total number of samples to generate from the diffusion model."
+        ),
+    )
     parser.add_argument(
         "--dataloader_num_workers",
         type=int,
@@ -536,6 +546,25 @@ def transform_images(examples):
             first_epoch = global_step // num_update_steps_per_epoch
             resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
+    if args.compute_fid:
+        # FID statistics are accumulated only on the main process
+        if accelerator.is_main_process:
+            fid = FrechetInceptionDistance(
+                normalize=True, reset_real_features=False, sync_on_compute=False
+            ).to(device=accelerator.device)
+
+        # update FID with the real images once, before training starts
+        fid_prog_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_main_process)
+        fid_prog_bar.set_description("Update FID - real images")
+        for step, batch in enumerate(train_dataloader):
+            imgs = (batch["input"].to(weight_dtype) + 1) / 2  # map [-1, 1] training inputs back to [0, 1] for FID
+            gathered_imgs = accelerator.gather(imgs)
+            if accelerator.is_main_process:
+                fid.update(gathered_imgs, real=True)
+
+            del gathered_imgs
+            fid_prog_bar.update(1)
+
     # Train!
     for epoch in range(first_epoch, args.num_epochs):
         model.train()
@@ -593,8 +622,8 @@ def transform_images(examples):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
-                    if global_step % args.checkpointing_steps == 0:
+                if global_step % args.checkpointing_steps == 0:
+                    if accelerator.is_main_process:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
                             checkpoints = os.listdir(args.output_dir)
@@ -618,6 +647,41 @@ def transform_images(examples):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
+
+                    # Compute FID on generated samples
+                    if args.compute_fid:
+                        unet = accelerator.unwrap_model(model)
+                        if args.use_ema:
+                            ema_model.store(unet.parameters())
+                            ema_model.copy_to(unet.parameters())
+
+                        generator = torch.Generator(device=accelerator.device).manual_seed(accelerator.process_index)
+                        pipeline = DDPMPipeline(
+                            unet=unet,
+                            scheduler=noise_scheduler,
+                        )
+                        num_batches = int(args.num_samples_to_evaluate / (args.eval_batch_size * accelerator.num_processes))
+                        for step in range(num_batches):
+                            generated = pipeline(
+                                generator=generator,
+                                batch_size=args.eval_batch_size,
+                                num_inference_steps=args.ddpm_num_inference_steps,
+                                output_type="np",
+                            ).images
+                            generated = torch.tensor(generated, device=accelerator.device, dtype=weight_dtype)
+                            generated = generated.permute(0, 3, 1, 2)
+                            g_generated = accelerator.gather(generated)
+                            if accelerator.is_main_process:
+                                fid.update(g_generated, real=False)
+                            del g_generated
+
+                        if accelerator.is_main_process:
+                            fid_value = float(fid.compute())
+                            fid.reset()
+                            accelerator.log({"FID": fid_value}, step=global_step)
+
+                        if args.use_ema:
+                            ema_model.restore(unet.parameters())
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
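For reference, below is a minimal, standalone sketch of the torchmetrics FID API that the training script relies on. The tensor shapes and sample counts are placeholders for illustration only; in the script, the real-image statistics are gathered from the training dataloader once before the loop starts, and the generated-image statistics are refreshed at every checkpointing step.

    import torch
    from torchmetrics.image.fid import FrechetInceptionDistance

    # normalize=True: inputs are float tensors in [0, 1] with shape (N, 3, H, W).
    # reset_real_features=False: real-image statistics survive fid.reset(), so they
    # only need to be accumulated once even when FID is computed repeatedly.
    fid = FrechetInceptionDistance(normalize=True, reset_real_features=False)

    real_batch = torch.rand(16, 3, 64, 64)       # placeholder for a batch of training images
    generated_batch = torch.rand(16, 3, 64, 64)  # placeholder for a batch of pipeline outputs

    fid.update(real_batch, real=True)        # accumulate real-image statistics
    fid.update(generated_batch, real=False)  # accumulate generated-image statistics
    print(float(fid.compute()))              # Frechet Inception Distance between the two sets
    fid.reset()                              # clears only the generated-image statistics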