+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import argparse
 import contextlib
 import gc
 from diffusers import AutoencoderKL
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
     import wandb
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.30.0.dev0")
+# check_min_version("0.33.0.dev0")
 
 logger = get_logger(__name__)
 
 
-def image_grid(imgs, rows, cols):
-    assert len(imgs) == rows * cols
-
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
-
-
 @torch.no_grad()
 def log_validation(
     vae, args, accelerator, weight_dtype, step, is_final_validation=False
@@ -111,7 +115,7 @@ def log_validation(
                 }
             )
         else:
-            logger.warn(f"image logging not implemented for {tracker.gen_images}")
+            logger.warn(f"image logging not implemented for {tracker.name}")
 
     gc.collect()
     torch.cuda.empty_cache()
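Note: the import change above drops the hand-rolled `image_grid` helper in favor of `diffusers.utils.make_image_grid`, which has the same rows/cols contract. A minimal sketch of the replacement (not part of this diff; the placeholder PIL images are for illustration only):

```python
from diffusers.utils import make_image_grid
from PIL import Image

# Placeholder images standing in for validation originals/reconstructions
# (assumption: equal-sized RGB images).
images = [Image.new("RGB", (64, 64), color=(40 * i, 0, 0)) for i in range(4)]

# Same contract as the removed helper: len(images) must equal rows * cols,
# and the result is a single PIL image with the inputs pasted into a grid.
grid = make_image_grid(images, rows=1, cols=len(images))
grid.save("images.png")
```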
@@ -123,7 +127,7 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None)
     img_str = ""
     if images is not None:
         img_str = "You can find some example images below.\n\n"
-        image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png"))
+        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png"))
         img_str += f"\n"
 
     model_description = f"""
@@ -875,23 +879,19 @@ def load_model_hook(models, input_dir):
         for step, batch in enumerate(train_dataloader):
             # Convert images to latent space and reconstruct from them
             targets = batch["pixel_values"].to(dtype=weight_dtype)
-            if accelerator.num_processes > 1:
-                posterior = vae.module.encode(targets).latent_dist
-            else:
-                posterior = vae.encode(targets).latent_dist
+            posterior = accelerator.unwrap_model(vae).encode(targets).latent_dist
             latents = posterior.sample()
-            if accelerator.num_processes > 1:
-                reconstructions = vae.module.decode(latents).sample
-            else:
-                reconstructions = vae.decode(latents).sample
+            reconstructions = accelerator.unwrap_model(vae).decode(latents).sample
 
             if (step // args.gradient_accumulation_steps) % 2 == 0 or global_step < args.disc_start:
                 with accelerator.accumulate(vae):
                     # reconstruction loss. Pixel level differences between input vs output
                     if args.rec_loss == "l2":
                         rec_loss = F.mse_loss(reconstructions.float(), targets.float(), reduction="none")
-                    else:
+                    elif args.rec_loss == "l1":
                         rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none")
+                    else:
+                        raise ValueError(f"Invalid reconstruction loss type: {args.rec_loss}")
                     # perceptual loss. The high level feature mean squared error loss
                     with torch.no_grad():
                         p_loss = perceptual_loss(reconstructions, targets)
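Note: the encode/decode change replaces the manual `num_processes` / `vae.module` branching with `accelerator.unwrap_model`, which returns the underlying model whether or not it has been wrapped for distributed training. A minimal sketch of the pattern (not part of this diff; the randomly initialized `AutoencoderKL` and random tensor are stand-ins for the trained VAE and a real batch):

```python
import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL

accelerator = Accelerator()
vae = AutoencoderKL()  # randomly initialized, default config; illustration only
vae = accelerator.prepare(vae)

# unwrap_model works the same in single- and multi-process runs, so no
# `if accelerator.num_processes > 1: ... vae.module ...` branch is needed.
unwrapped_vae = accelerator.unwrap_model(vae)

with torch.no_grad():
    pixels = torch.randn(1, 3, 64, 64, device=accelerator.device)
    posterior = unwrapped_vae.encode(pixels).latent_dist
    reconstruction = unwrapped_vae.decode(posterior.sample()).sample
```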