[Diffusers] [SDv1.5 - Text-to-Image Finetuning] Choose the best checkpoint for inference. #7248
Replies: 1 comment
Validation loss is also a somewhat unreliable metric for diffusion models. Because the timesteps are sampled at random, the sampled timesteps affect the loss value quite a bit. In my experience, FID together with CLIP score, computed on your validation subset, is a good metric. If you want something quicker, CMMD could also be a good choice. Note, however, that CMMD uses 30k samples randomly drawn from the MSCOCO validation set, and MSCOCO images are not particularly high-resolution. If your pipeline generates images at a much higher resolution, the CMMD metric may be misleading. It is therefore best to have your own validation set on which you compute metrics such as CMMD and FID.
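Here is a minimal sketch of the CMMD computation. It assumes you have already extracted L2-normalised CLIP image embeddings for a reference set (for example your own validation images) and for the generated images. CMMD is a maximum mean discrepancy (MMD) with a Gaussian RBF kernel over CLIP embeddings. The bandwidth `SIGMA = 10` and the output factor `SCALE = 1000` are what I believe the reference implementation uses, so treat them as assumptions. `cmmd`, `_mean_rbf` and `l2_normalize` are hypothetical helper names.

```python
import numpy as np

SIGMA = 10.0    # RBF bandwidth (assumption: believed to match the reference CMMD code)
SCALE = 1000.0  # output scaling (same assumption)


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row (one embedding per row) to unit length."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def _mean_rbf(a: np.ndarray, b: np.ndarray, sigma: float) -> float:
    """Mean Gaussian RBF kernel value over all pairs (a_i, b_j), diagonal included."""
    gamma = 1.0 / (2.0 * sigma**2)
    sq_dists = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    return float(np.mean(np.exp(-gamma * sq_dists)))


def cmmd(ref_emb: np.ndarray, gen_emb: np.ndarray,
         sigma: float = SIGMA, scale: float = SCALE) -> float:
    """Biased MMD^2 estimate between reference and generated CLIP embeddings."""
    k_xx = _mean_rbf(ref_emb, ref_emb, sigma)
    k_yy = _mean_rbf(gen_emb, gen_emb, sigma)
    k_xy = _mean_rbf(ref_emb, gen_emb, sigma)
    return scale * (k_xx + k_yy - 2.0 * k_xy)
```

Lower is better. The value depends on the reference set, so compare checkpoints only against the same reference embeddings. This is one reason to use your own validation set rather than a fixed external one.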
Search before asking:
I have searched for existing answers, examined the code provided by Hugging Face, and explored related resources on selecting the optimal checkpoint (the UNet checkpoint) for Stable Diffusion after fine-tuning:
Context:
Version of packages used:
Purpose of training:
Command for training:
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --use_ema \
  --resolution=224 --center_crop --random_flip \
  --train_batch_size=128 --gradient_accumulation_steps=1 --gradient_checkpointing \
  --num_train_epochs=10 --learning_rate=1e-05 --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="Testing" --image_column="image_path" \
  --enable_xformers_memory_efficient_attention --snr_gamma=5 \
  --validation_epochs=1 --validation_prompts="Some qualitative prompts to test"
Results:
Question:
I trained my model for approximately 10 epochs and saved a checkpoint every 1500 steps, and I am unsure which checkpoint to use for inference. Normally I would look at both the training loss and the validation loss to shortlist a few candidate checkpoints for further analysis. However, the diffusers training script does not compute a validation loss, and I am concerned that simply picking the final checkpoint might give me an overfitted model.
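One way to make a validation loss comparable across checkpoints is to freeze the randomness that the reply above identifies as the problem. Every checkpoint is evaluated on the same validation latents, at the same evenly spaced timesteps, with the same seeded noise. The sketch below shows this in NumPy under several assumptions:

- A hypothetical `predict_noise(noisy_latents, t)` callable stands in for the UNet, with text conditioning omitted.
- A toy linear beta schedule is used. For SD v1.5 you would instead take `alphas_cumprod` from the pipeline's `DDPMScheduler`.
- It reports the plain epsilon-prediction MSE and does not reproduce the `--snr_gamma` weighting used during training.

```python
import numpy as np


def linear_alpha_bars(num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a toy linear beta schedule (assumption)."""
    betas = np.linspace(beta_start, beta_end, num_train_timesteps)
    return np.cumprod(1.0 - betas)


def fixed_timestep_val_loss(predict_noise, latents, alpha_bars,
                            num_eval_timesteps=10, seed=0):
    """Epsilon-prediction MSE on a fixed timestep grid with seeded noise.

    Reusing the same seed, latents and grid for every checkpoint means all
    checkpoints are scored on identical (x_0, t, noise) triples.
    """
    rng = np.random.default_rng(seed)
    timesteps = np.linspace(0, len(alpha_bars) - 1, num_eval_timesteps).round().astype(int)
    losses = []
    for t in timesteps:
        noise = rng.standard_normal(latents.shape)
        a_bar = alpha_bars[t]
        # Forward process: x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * noise
        noisy = np.sqrt(a_bar) * latents + np.sqrt(1.0 - a_bar) * noise
        pred = predict_noise(noisy, t)
        losses.append(np.mean((pred - noise) ** 2))
    return float(np.mean(losses))


# Usage with stand-ins: 224 px images give 28x28 latents with 4 channels in SD v1.5.
alpha_bars = linear_alpha_bars()
val_latents = np.random.default_rng(123).standard_normal((8, 4, 28, 28))  # stand-in for VAE latents
zero_model = lambda noisy, t: np.zeros_like(noisy)  # trivial baseline "model"
print(fixed_timestep_val_loss(zero_model, val_latents, alpha_bars))
```

Because the noise and timesteps are fixed, differences between checkpoints reflect the model rather than sampling luck. The loss still only measures denoising error, so it complements FID or CLIP scores rather than replacing them.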
Qualitative method:
Quantitative method:
CLIP evaluation between original images and images generated from corresponding texts (comparing image embeddings):
CLIP evaluation between generated images and corresponding texts (comparing image and textual embeddings):
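Both CLIP-based evaluations reduce to cosine similarities between embeddings. The sketch below assumes the image and text embeddings have already been computed with a CLIP model (for example via `CLIPModel.get_image_features` and `get_text_features` in `transformers`) and are paired row by row. In `clip_score`, the factor of 100 and the per-pair clamp at zero follow a common CLIPScore-style convention and are assumptions here. All function names are hypothetical.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row (one embedding per row) to unit length."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def paired_cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity between row i of a and row i of b, for every i."""
    return np.sum(l2_normalize(a) * l2_normalize(b), axis=-1)


def clip_image_image(real_img_emb: np.ndarray, gen_img_emb: np.ndarray) -> float:
    """Mean cosine similarity between each original image and the image generated from its caption."""
    return float(np.mean(paired_cosine(real_img_emb, gen_img_emb)))


def clip_score(gen_img_emb: np.ndarray, text_emb: np.ndarray, scale: float = 100.0) -> float:
    """CLIPScore-style alignment: scale * max(cos, 0) per image-text pair, then averaged."""
    cos = paired_cosine(gen_img_emb, text_emb)
    return float(np.mean(scale * np.maximum(cos, 0.0)))
```

Higher is better for both. The image-image variant rewards reproducing the original images, while the image-text variant rewards prompt adherence. A checkpoint can improve on one while degrading on the other.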
FID evaluation between a collection of images generated from the texts and the collection of original images (computed with torchmetrics):
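For reference, FID is the Fréchet distance between Gaussians fitted to Inception features of the two image sets:

||μ_a − μ_b||² + Tr(Σ_a) + Tr(Σ_b) − 2 Tr((Σ_a Σ_b)^(1/2))

torchmetrics' `FrechetInceptionDistance` handles feature extraction for you. The sketch below only illustrates the distance itself on arbitrary precomputed features, with one sample per row. It computes the square-root trace term from eigenvalues instead of `scipy.linalg.sqrtm`, which is a substitution for illustration. FID estimates are biased at small sample sizes, so compare checkpoints using the same number of generated images.

```python
import numpy as np


def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Frechet distance between Gaussians fitted to two feature sets (one sample per row)."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((cov_a @ cov_b)^(1/2)) is the sum of the square roots of the eigenvalues of
    # cov_a @ cov_b. These are real and non-negative for PSD covariances, so taking
    # .real and clipping at zero only discards round-off.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_sqrt)


rng = np.random.default_rng(0)
real_feats = rng.standard_normal((500, 64))       # stand-in for features of real images
gen_feats = rng.standard_normal((500, 64)) + 0.1  # stand-in for features of generated images
print(frechet_distance(real_feats, gen_feats))
```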
In conclusion, because validation loss is neither implemented nor discussed, I am seeking guidance from the community on the best approach for selecting checkpoints. Despite researching this matter thoroughly, my knowledge remains limited, and I would welcome additional perspectives and insights on this issue.
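One way to turn per-checkpoint metrics into a shortlist is to keep only Pareto-optimal checkpoints and then compare those survivors qualitatively. A checkpoint is Pareto-optimal if no other checkpoint beats it on both FID (lower is better) and CLIP score (higher is better). The sketch assumes checkpoint directories named `checkpoint-<step>`, which is how I understand the diffusers example script names the states it saves. `pareto_checkpoints`, `checkpoint_step`, and the scores in the usage lines are hypothetical.

```python
import re


def checkpoint_step(name: str) -> int:
    """Parse the training step from a directory name such as 'checkpoint-1500'."""
    match = re.fullmatch(r"checkpoint-(\d+)", name)
    if match is None:
        raise ValueError(f"unexpected checkpoint name: {name!r}")
    return int(match.group(1))


def pareto_checkpoints(scores: dict[str, tuple[float, float]]) -> list[str]:
    """Checkpoints not dominated on (lower FID, higher CLIP score), ordered by step.

    `scores` maps a checkpoint name to (fid, clip_score).
    """
    front = []
    for name, (fid, clip) in scores.items():
        # Another checkpoint dominates this one if it is at least as good on both
        # metrics and strictly better on at least one of them.
        dominated = any(
            f2 <= fid and c2 >= clip and (f2 < fid or c2 > clip)
            for other, (f2, c2) in scores.items()
            if other != name
        )
        if not dominated:
            front.append(name)
    return sorted(front, key=checkpoint_step)


# Hypothetical scores, for illustration only (not real results):
example = {
    "checkpoint-1500": (30.0, 25.0),
    "checkpoint-3000": (25.0, 26.0),
    "checkpoint-4500": (24.0, 25.5),
    "checkpoint-6000": (26.0, 25.0),
}
print(pareto_checkpoints(example))
```

This avoids choosing an arbitrary weighting between FID and CLIP score. The trade-off between the remaining candidates is left to the qualitative prompts.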
Related figures: