2525import  accelerate 
2626import  numpy  as  np 
2727import  torch 
28- import  torch .nn .functional  as  F 
2928import  torch .utils .checkpoint 
3029import  transformers 
3130from  accelerate  import  Accelerator 
4342import  diffusers 
4443from  diffusers  import  AutoencoderKL , FlowMatchEulerDiscreteScheduler , FluxControlPipeline , FluxTransformer2DModel 
4544from  diffusers .optimization  import  get_scheduler 
46- from  diffusers .training_utils  import  cast_training_params , compute_density_for_timestep_sampling , free_memory 
45+ from  diffusers .training_utils  import  (
46+     cast_training_params ,
47+     compute_density_for_timestep_sampling ,
48+     compute_loss_weighting_for_sd3 ,
49+     free_memory ,
50+ )
4751from  diffusers .utils  import  check_min_version , is_wandb_available , make_image_grid 
4852from  diffusers .utils .hub_utils  import  load_or_create_model_card , populate_model_card 
4953from  diffusers .utils .torch_utils  import  is_compiled_module 
@@ -550,7 +554,7 @@ def parse_args(input_args=None):
550554    parser .add_argument (
551555        "--weighting_scheme" ,
552556        type = str ,
553-         default = "logit_normal " ,
557+         default = "none " ,
554558        choices = ["sigma_sqrt" , "logit_normal" , "mode" , "cosmap" , "none" ],
555559        help = ('We default to the "none" weighting scheme for uniform sampling and uniform loss' ),
556560    )
@@ -566,11 +570,6 @@ def parse_args(input_args=None):
566570        default = 1.29 ,
567571        help = "Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`." ,
568572    )
569-     parser .add_argument (
570-         "--enable_model_cpu_offload" ,
571-         action = "store_true" ,
572-         help = "Enable model cpu offload and save memory." ,
573-     )
574573
575574    if  input_args  is  not None :
576575        args  =  parser .parse_args (input_args )
@@ -672,7 +671,8 @@ def prepare_train_dataset(dataset, accelerator):
672671        [
673672            transforms .Resize (args .resolution , interpolation = transforms .InterpolationMode .BILINEAR ),
674673            transforms .CenterCrop (args .resolution ),
675-             transforms .Lambda (lambda  x : x  /  127.5  -  1.0 ),
674+             transforms .ToTensor (),
675+             transforms .Normalize (mean = [0.5 , 0.5 , 0.5 ], std = [0.5 , 0.5 , 0.5 ]),
676676        ]
677677    )
678678
@@ -735,7 +735,7 @@ def main(args):
735735
736736    # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices. 
737737    if  torch .backends .mps .is_available ():
738-         print ("MPS is enabled. Disabling AMP." )
738+         logger . info ("MPS is enabled. Disabling AMP." )
739739        accelerator .native_amp  =  False 
740740
741741    # Make one log on every process with the configuration for debugging. 
@@ -776,6 +776,7 @@ def main(args):
776776        revision = args .revision ,
777777        variant = args .variant ,
778778    )
779+     vae_scale_factor  =  2  **  (len (vae .config .block_out_channels ) -  1 )
779780    flux_transformer  =  FluxTransformer2DModel .from_pretrained (
780781        args .pretrained_model_name_or_path ,
781782        subfolder = "transformer" ,
@@ -817,6 +818,8 @@ def main(args):
817818        new_linear .weight [:, :initial_input_channels ].copy_ (flux_transformer .x_embedder .weight )
818819        new_linear .bias .copy_ (flux_transformer .x_embedder .bias )
819820        flux_transformer .x_embedder  =  new_linear 
821+ 
822+     assert  torch .all (new_linear .weight [:, initial_input_channels :].data  ==  0 )
820823    flux_transformer .register_to_config (in_channels = initial_input_channels  *  2 )
821824
822825    if  args .lora_layers  is  not None :
@@ -1092,24 +1095,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
10921095                # offload vae to CPU. 
10931096                vae .cpu ()
10941097
1095-                 # pack the latents. 
1096-                 packed_pixel_latents  =  FluxControlPipeline ._pack_latents (
1097-                     pixel_latents ,
1098-                     batch_size = pixel_latents .shape [0 ],
1099-                     num_channels_latents = pixel_latents .shape [1 ],
1100-                     height = pixel_latents .shape [2 ],
1101-                     width = pixel_latents .shape [3 ],
1102-                 )
1103-                 packed_control_latents  =  FluxControlPipeline ._pack_latents (
1104-                     pixel_latents ,
1105-                     batch_size = control_latents .shape [0 ],
1106-                     num_channels_latents = control_latents .shape [1 ],
1107-                     height = control_latents .shape [2 ],
1108-                     width = control_latents .shape [3 ],
1109-                 )
1110-                 # concate across channels. 
1111-                 latent_model_input  =  torch .cat ([packed_pixel_latents , packed_control_latents ], dim = 2 )
1112- 
11131098                # Sample a random timestep for each image 
11141099                # for weighting schemes where we sample timesteps non-uniformly 
11151100                bsz  =  pixel_latents .shape [0 ]
@@ -1122,25 +1107,37 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11221107                    mode_scale = args .mode_scale ,
11231108                )
11241109                indices  =  (u  *  noise_scheduler_copy .config .num_train_timesteps ).long ()
1125-                 timesteps  =  noise_scheduler_copy .timesteps [indices ].to (device = latent_model_input .device )
1110+                 timesteps  =  noise_scheduler_copy .timesteps [indices ].to (device = pixel_latents .device )
11261111
11271112                # Add noise according to flow matching. 
1128-                 sigmas  =  get_sigmas (timesteps , n_dim = latent_model_input .ndim , dtype = latent_model_input .dtype )
1129-                 noisy_model_input  =  (1.0  -  sigmas ) *  latent_model_input  +  sigmas  *  noise 
1113+                 sigmas  =  get_sigmas (timesteps , n_dim = pixel_latents .ndim , dtype = pixel_latents .dtype )
1114+                 noisy_model_input  =  (1.0  -  sigmas ) *  pixel_latents  +  sigmas  *  noise 
1115+                 # Concatenate across channels. 
1116+                 # Question: Should we concatenate before adding noise? 
1117+                 concatenated_noisy_model_input  =  torch .cat ([noisy_model_input , control_latents ], dim = 1 )
1118+ 
1119+                 # pack the latents. 
1120+                 packed_noisy_model_input  =  FluxControlPipeline ._pack_latents (
1121+                     concatenated_noisy_model_input ,
1122+                     batch_size = bsz ,
1123+                     num_channels_latents = concatenated_noisy_model_input .shape [1 ],
1124+                     height = concatenated_noisy_model_input .shape [2 ],
1125+                     width = concatenated_noisy_model_input .shape [3 ],
1126+                 )
11301127
11311128                # latent image ids for RoPE. 
11321129                latent_image_ids  =  FluxControlPipeline ._prepare_latent_image_ids (
1133-                     pixel_latents . shape [ 0 ] ,
1134-                     pixel_latents .shape [2 ] //  2 ,
1135-                     pixel_latents .shape [3 ] //  2 ,
1130+                     bsz ,
1131+                     concatenated_noisy_model_input .shape [2 ] //  2 ,
1132+                     concatenated_noisy_model_input .shape [3 ] //  2 ,
11361133                    accelerator .device ,
11371134                    weight_dtype ,
11381135                )
11391136
11401137                # handle guidance 
11411138                if  flux_transformer .config .guidance_embeds :
11421139                    guidance_vec  =  torch .full (
1143-                         (noisy_model_input . shape [ 0 ] ,),
1140+                         (bsz ,),
11441141                        args .guidance_scale ,
11451142                        device = noisy_model_input .device ,
11461143                        dtype = weight_dtype ,
@@ -1152,12 +1149,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11521149                captions  =  batch ["captions" ]
11531150                text_encoding_pipeline  =  text_encoding_pipeline .to ("cuda" )
11541151                with  torch .no_grad ():
1155-                     prompt_embeds , pooled_prompt_embeds , text_ids  =  text_encoding_pipeline .encode_prompt (captions )
1152+                     prompt_embeds , pooled_prompt_embeds , text_ids  =  text_encoding_pipeline .encode_prompt (
1153+                         captions , prompt_2 = None 
1154+                     )
11561155                text_encoding_pipeline  =  text_encoding_pipeline .to ("cuda" )
11571156
11581157                # Predict. 
1159-                 noise_pred  =  flux_transformer (
1160-                     hidden_states = noisy_model_input ,
1158+                 model_pred  =  flux_transformer (
1159+                     hidden_states = packed_noisy_model_input ,
11611160                    timestep = timesteps  /  1000 ,
11621161                    guidance = guidance_vec ,
11631162                    pooled_projections = pooled_prompt_embeds ,
@@ -1166,10 +1165,24 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11661165                    img_ids = latent_image_ids ,
11671166                    return_dict = False ,
11681167                )[0 ]
1169- 
1170-                 loss  =  F .mse_loss (noise_pred .float (), (noise  -  pixel_latents ).float (), reduction = "mean" )
1168+                 model_pred  =  FluxControlPipeline ._unpack_latents (
1169+                     model_pred ,
1170+                     height = noisy_model_input .shape [2 ] *  vae_scale_factor ,
1171+                     width = noisy_model_input .shape [3 ] *  vae_scale_factor ,
1172+                     vae_scale_factor = vae_scale_factor ,
1173+                 )
1174+                 # these weighting schemes use a uniform timestep sampling 
1175+                 # and instead post-weight the loss 
1176+                 weighting  =  compute_loss_weighting_for_sd3 (weighting_scheme = args .weighting_scheme , sigmas = sigmas )
1177+ 
1178+                 # flow-matching loss 
1179+                 target  =  noise  -  pixel_latents 
1180+                 loss  =  torch .mean (
1181+                     (weighting .float () *  (model_pred .float () -  target .float ()) **  2 ).reshape (target .shape [0 ], - 1 ),
1182+                     1 ,
1183+                 )
1184+                 loss  =  loss .mean ()
11711185                accelerator .backward (loss )
1172-                 # Check if the gradient of each model parameter contains NaN 
11731186
11741187                if  accelerator .sync_gradients :
11751188                    params_to_clip  =  flux_transformer .parameters ()
0 commit comments