1414# See the License for the specific language governing permissions and
1515
1616import argparse
17+ import contextlib
1718import gc
1819import hashlib
1920import itertools
21+ import json
2022import logging
2123import math
2224import os
3739from accelerate import Accelerator
3840from accelerate .logging import get_logger
3941from accelerate .utils import DistributedDataParallelKwargs , ProjectConfiguration , set_seed
40- from huggingface_hub import create_repo , upload_folder
42+ from huggingface_hub import create_repo , hf_hub_download , upload_folder
4143from packaging import version
4244from peft import LoraConfig , set_peft_model_state_dict
4345from peft .utils import get_peft_model_state_dict
5557 AutoencoderKL ,
5658 DDPMScheduler ,
5759 DPMSolverMultistepScheduler ,
60+ EDMEulerScheduler ,
61+ EulerDiscreteScheduler ,
5862 StableDiffusionXLPipeline ,
5963 UNet2DConditionModel ,
6064)
7983logger = get_logger (__name__ )
8084
8185
def determine_scheduler_type(pretrained_model_name_or_path, revision):
    """Read the scheduler type recorded in a pipeline's ``model_index.json``.

    Args:
        pretrained_model_name_or_path: If this is an existing local directory,
            ``model_index.json`` is read from inside it. Otherwise the value is
            passed to ``hf_hub_download`` as the ``repo_id``.
        revision: Forwarded to ``hf_hub_download``; not used for a local
            directory.

    Returns:
        The second element of the ``"scheduler"`` entry of the index file
        (presumably the scheduler class name -- callers test it for the
        substring ``"EDM"``).

    Raises:
        KeyError: If the index file has no ``"scheduler"`` entry.
    """
    model_index_filename = "model_index.json"
    if os.path.isdir(pretrained_model_name_or_path):
        model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
    else:
        model_index = hf_hub_download(
            repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
        )

    # JSON text is UTF-8 by specification; do not depend on the platform's
    # locale encoding when decoding the file.
    with open(model_index, "r", encoding="utf-8") as f:
        scheduler_type = json.load(f)["scheduler"][1]
    return scheduler_type
98+
99+
82100def save_model_card (
83101 repo_id : str ,
84102 use_dora : bool ,
@@ -370,6 +388,11 @@ def parse_args(input_args=None):
370388 " `args.validation_prompt` multiple times: `args.num_validation_images`."
371389 ),
372390 )
391+ parser .add_argument (
392+ "--do_edm_style_training" ,
393+ action = "store_true" ,
394+ help = "Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364." ,
395+ )
373396 parser .add_argument (
374397 "--with_prior_preservation" ,
375398 default = False ,
@@ -1117,6 +1140,8 @@ def main(args):
11171140 "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
11181141 " Please use `huggingface-cli login` to authenticate with the Hub."
11191142 )
1143+ if args .do_edm_style_training and args .snr_gamma is not None :
1144+ raise ValueError ("Min-SNR formulation is not supported when conducting EDM-style training." )
11201145
11211146 logging_dir = Path (args .output_dir , args .logging_dir )
11221147
@@ -1234,7 +1259,19 @@ def main(args):
12341259 )
12351260
12361261 # Load scheduler and models
1237- noise_scheduler = DDPMScheduler .from_pretrained (args .pretrained_model_name_or_path , subfolder = "scheduler" )
1262+ scheduler_type = determine_scheduler_type (args .pretrained_model_name_or_path , args .revision )
1263+ if "EDM" in scheduler_type :
1264+ args .do_edm_style_training = True
1265+ noise_scheduler = EDMEulerScheduler .from_pretrained (args .pretrained_model_name_or_path , subfolder = "scheduler" )
1266+ logger .info ("Performing EDM-style training!" )
1267+ elif args .do_edm_style_training :
1268+ noise_scheduler = EulerDiscreteScheduler .from_pretrained (
1269+ args .pretrained_model_name_or_path , subfolder = "scheduler"
1270+ )
1271+ logger .info ("Performing EDM-style training!" )
1272+ else :
1273+ noise_scheduler = DDPMScheduler .from_pretrained (args .pretrained_model_name_or_path , subfolder = "scheduler" )
1274+
12381275 text_encoder_one = text_encoder_cls_one .from_pretrained (
12391276 args .pretrained_model_name_or_path , subfolder = "text_encoder" , revision = args .revision , variant = args .variant
12401277 )
@@ -1252,7 +1289,12 @@ def main(args):
12521289 revision = args .revision ,
12531290 variant = args .variant ,
12541291 )
1255- vae_scaling_factor = vae .config .scaling_factor
1292+ latents_mean = latents_std = None
1293+ if hasattr (vae .config , "latents_mean" ) and vae .config .latents_mean is not None :
1294+ latents_mean = torch .tensor (vae .config .latents_mean ).view (1 , 4 , 1 , 1 )
1295+ if hasattr (vae .config , "latents_std" ) and vae .config .latents_std is not None :
1296+ latents_std = torch .tensor (vae .config .latents_std ).view (1 , 4 , 1 , 1 )
1297+
12561298 unet = UNet2DConditionModel .from_pretrained (
12571299 args .pretrained_model_name_or_path , subfolder = "unet" , revision = args .revision , variant = args .variant
12581300 )
@@ -1790,6 +1832,19 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
17901832 disable = not accelerator .is_local_main_process ,
17911833 )
17921834
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigma for each given timestep.

    Each entry of ``timesteps`` is matched against ``noise_scheduler.timesteps``
    and the sigma at the matching position of ``noise_scheduler.sigmas`` is
    selected. The flattened result gets trailing singleton dimensions appended
    until it has ``n_dim`` dimensions.

    Args:
        timesteps: Tensor of timestep values; each must match exactly one
            entry of ``noise_scheduler.timesteps`` (``.item()`` fails otherwise).
        n_dim: Minimum number of dimensions of the returned tensor.
        dtype: Dtype the scheduler sigmas are cast to.

    Returns:
        Tensor of sigmas on ``accelerator.device``.
    """
    # TODO: revisit other sampling algorithms
    device = accelerator.device
    all_sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    all_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)

    # Position of every requested timestep inside the scheduler's schedule.
    positions = []
    for t in timesteps:
        positions.append((all_timesteps == t).nonzero().item())

    sigma = all_sigmas[positions].flatten()
    # Pad with trailing singleton axes up to the requested rank.
    for _ in range(max(n_dim - sigma.dim(), 0)):
        sigma = sigma.unsqueeze(-1)
    return sigma
1847+
17931848 if args .train_text_encoder :
17941849 num_train_epochs_text_encoder = int (args .train_text_encoder_frac * args .num_train_epochs )
17951850 elif args .train_text_encoder_ti : # args.train_text_encoder_ti
@@ -1841,9 +1896,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
18411896 pixel_values = batch ["pixel_values" ].to (dtype = vae .dtype )
18421897 model_input = vae .encode (pixel_values ).latent_dist .sample ()
18431898
1844- model_input = model_input * vae_scaling_factor
1845- if args .pretrained_vae_model_name_or_path is None :
1846- model_input = model_input .to (weight_dtype )
1899+ if latents_mean is None and latents_std is None :
1900+ model_input = model_input * vae .config .scaling_factor
1901+ if args .pretrained_vae_model_name_or_path is None :
1902+ model_input = model_input .to (weight_dtype )
1903+ else :
1904+ latents_mean = latents_mean .to (device = model_input .device , dtype = model_input .dtype )
1905+ latents_std = latents_std .to (device = model_input .device , dtype = model_input .dtype )
1906+ model_input = (model_input - latents_mean ) * vae .config .scaling_factor / latents_std
1907+ model_input = model_input .to (dtype = weight_dtype )
18471908
18481909 # Sample noise that we'll add to the latents
18491910 noise = torch .randn_like (model_input )
@@ -1854,15 +1915,32 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
18541915 )
18551916
18561917 bsz = model_input .shape [0 ]
1918+
18571919 # Sample a random timestep for each image
1858- timesteps = torch .randint (
1859- 0 , noise_scheduler .config .num_train_timesteps , (bsz ,), device = model_input .device
1860- )
1861- timesteps = timesteps .long ()
1920+ if not args .do_edm_style_training :
1921+ timesteps = torch .randint (
1922+ 0 , noise_scheduler .config .num_train_timesteps , (bsz ,), device = model_input .device
1923+ )
1924+ timesteps = timesteps .long ()
1925+ else :
1926+ # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
1927+ # instead of discrete timesteps, so here we sample indices to get the noise levels
1928+ # from `scheduler.timesteps`
1929+ indices = torch .randint (0 , noise_scheduler .config .num_train_timesteps , (bsz ,))
1930+ timesteps = noise_scheduler .timesteps [indices ].to (device = model_input .device )
18621931
18631932 # Add noise to the model input according to the noise magnitude at each timestep
18641933 # (this is the forward diffusion process)
18651934 noisy_model_input = noise_scheduler .add_noise (model_input , noise , timesteps )
1935+ # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
1936+ # We then precondition the final model inputs based on these sigmas instead of the timesteps.
1937+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
1938+ if args .do_edm_style_training :
1939+ sigmas = get_sigmas (timesteps , len (noisy_model_input .shape ), noisy_model_input .dtype )
1940+ if "EDM" in scheduler_type :
1941+ inp_noisy_latents = noise_scheduler .precondition_inputs (noisy_model_input , sigmas )
1942+ else :
1943+ inp_noisy_latents = noisy_model_input / ((sigmas ** 2 + 1 ) ** 0.5 )
18661944
18671945 # time ids
18681946 add_time_ids = torch .cat (
@@ -1888,7 +1966,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
18881966 }
18891967 prompt_embeds_input = prompt_embeds .repeat (elems_to_repeat_text_embeds , 1 , 1 )
18901968 model_pred = unet (
1891- noisy_model_input ,
1969+ inp_noisy_latents if args . do_edm_style_training else noisy_model_input ,
18921970 timesteps ,
18931971 prompt_embeds_input ,
18941972 added_cond_kwargs = unet_added_conditions ,
@@ -1906,14 +1984,42 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
19061984 )
19071985 prompt_embeds_input = prompt_embeds .repeat (elems_to_repeat_text_embeds , 1 , 1 )
19081986 model_pred = unet (
1909- noisy_model_input , timesteps , prompt_embeds_input , added_cond_kwargs = unet_added_conditions
1987+ inp_noisy_latents if args .do_edm_style_training else noisy_model_input ,
1988+ timesteps ,
1989+ prompt_embeds_input ,
1990+ added_cond_kwargs = unet_added_conditions ,
19101991 ).sample
19111992
1993+ weighting = None
1994+ if args .do_edm_style_training :
1995+ # Similar to the input preconditioning, the model predictions are also preconditioned
1996+ # on noised model inputs (before preconditioning) and the sigmas.
1997+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
1998+ if "EDM" in scheduler_type :
1999+ model_pred = noise_scheduler .precondition_outputs (noisy_model_input , model_pred , sigmas )
2000+ else :
2001+ if noise_scheduler .config .prediction_type == "epsilon" :
2002+ model_pred = model_pred * (- sigmas ) + noisy_model_input
2003+ elif noise_scheduler .config .prediction_type == "v_prediction" :
2004+ model_pred = model_pred * (- sigmas / (sigmas ** 2 + 1 ) ** 0.5 ) + (
2005+ noisy_model_input / (sigmas ** 2 + 1 )
2006+ )
2007+ # We are not doing weighting here because it tends to result in numerical problems.
2008+ # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
2009+ # There might be other alternatives for weighting as well:
2010+ # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
2011+ if "EDM" not in scheduler_type :
2012+ weighting = (sigmas ** - 2.0 ).float ()
2013+
19122014 # Get the target for loss depending on the prediction type
19132015 if noise_scheduler .config .prediction_type == "epsilon" :
1914- target = noise
2016+ target = model_input if args . do_edm_style_training else noise
19152017 elif noise_scheduler .config .prediction_type == "v_prediction" :
1916- target = noise_scheduler .get_velocity (model_input , noise , timesteps )
2018+ target = (
2019+ model_input
2020+ if args .do_edm_style_training
2021+ else noise_scheduler .get_velocity (model_input , noise , timesteps )
2022+ )
19172023 else :
19182024 raise ValueError (f"Unknown prediction type { noise_scheduler .config .prediction_type } " )
19192025
@@ -1923,10 +2029,28 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
19232029 target , target_prior = torch .chunk (target , 2 , dim = 0 )
19242030
19252031 # Compute prior loss
1926- prior_loss = F .mse_loss (model_pred_prior .float (), target_prior .float (), reduction = "mean" )
2032+ if weighting is not None :
2033+ prior_loss = torch .mean (
2034+ (weighting .float () * (model_pred_prior .float () - target_prior .float ()) ** 2 ).reshape (
2035+ target_prior .shape [0 ], - 1
2036+ ),
2037+ 1 ,
2038+ )
2039+ prior_loss = prior_loss .mean ()
2040+ else :
2041+ prior_loss = F .mse_loss (model_pred_prior .float (), target_prior .float (), reduction = "mean" )
19272042
19282043 if args .snr_gamma is None :
1929- loss = F .mse_loss (model_pred .float (), target .float (), reduction = "mean" )
2044+ if weighting is not None :
2045+ loss = torch .mean (
2046+ (weighting .float () * (model_pred .float () - target .float ()) ** 2 ).reshape (
2047+ target .shape [0 ], - 1
2048+ ),
2049+ 1 ,
2050+ )
2051+ loss = loss .mean ()
2052+ else :
2053+ loss = F .mse_loss (model_pred .float (), target .float (), reduction = "mean" )
19302054 else :
19312055 # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
19322056 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
@@ -2049,26 +2173,32 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
20492173 # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
20502174 scheduler_args = {}
20512175
2052- if "variance_type" in pipeline .scheduler .config :
2053- variance_type = pipeline .scheduler .config .variance_type
2176+ if not args .do_edm_style_training :
2177+ if "variance_type" in pipeline .scheduler .config :
2178+ variance_type = pipeline .scheduler .config .variance_type
20542179
2055- if variance_type in ["learned" , "learned_range" ]:
2056- variance_type = "fixed_small"
2180+ if variance_type in ["learned" , "learned_range" ]:
2181+ variance_type = "fixed_small"
20572182
2058- scheduler_args ["variance_type" ] = variance_type
2183+ scheduler_args ["variance_type" ] = variance_type
20592184
2060- pipeline .scheduler = DPMSolverMultistepScheduler .from_config (
2061- pipeline .scheduler .config , ** scheduler_args
2062- )
2185+ pipeline .scheduler = DPMSolverMultistepScheduler .from_config (
2186+ pipeline .scheduler .config , ** scheduler_args
2187+ )
20632188
20642189 pipeline = pipeline .to (accelerator .device )
20652190 pipeline .set_progress_bar_config (disable = True )
20662191
20672192 # run inference
20682193 generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed ) if args .seed else None
20692194 pipeline_args = {"prompt" : args .validation_prompt }
2195+ inference_ctx = (
2196+ contextlib .nullcontext ()
2197+ if "playground" in args .pretrained_model_name_or_path
2198+ else torch .cuda .amp .autocast ()
2199+ )
20702200
2071- with torch . cuda . amp . autocast () :
2201+ with inference_ctx :
20722202 images = [
20732203 pipeline (** pipeline_args , generator = generator ).images [0 ]
20742204 for _ in range (args .num_validation_images )
@@ -2144,15 +2274,18 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
21442274 # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
21452275 scheduler_args = {}
21462276
2147- if "variance_type" in pipeline .scheduler .config :
2148- variance_type = pipeline .scheduler .config .variance_type
2277+ if not args .do_edm_style_training :
2278+ if "variance_type" in pipeline .scheduler .config :
2279+ variance_type = pipeline .scheduler .config .variance_type
21492280
2150- if variance_type in ["learned" , "learned_range" ]:
2151- variance_type = "fixed_small"
2281+ if variance_type in ["learned" , "learned_range" ]:
2282+ variance_type = "fixed_small"
21522283
2153- scheduler_args ["variance_type" ] = variance_type
2284+ scheduler_args ["variance_type" ] = variance_type
21542285
2155- pipeline .scheduler = DPMSolverMultistepScheduler .from_config (pipeline .scheduler .config , ** scheduler_args )
2286+ pipeline .scheduler = DPMSolverMultistepScheduler .from_config (
2287+ pipeline .scheduler .config , ** scheduler_args
2288+ )
21562289
21572290 # load attention processors
21582291 pipeline .load_lora_weights (args .output_dir )
0 commit comments