 import numpy as np
 import torch
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
@@ -1222,6 +1223,9 @@ def main(args):
         kwargs_handlers=[kwargs],
     )
 
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
@@ -1438,17 +1442,20 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_one_lora_layers_to_save = None
             modules_to_save = {}
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["transformer"] = model
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    model = unwrap_model(model)
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
             FluxKontextPipeline.save_lora_weights(
                 output_dir,
@@ -1461,15 +1468,25 @@ def load_model_hook(models, input_dir):
         transformer_ = None
         text_encoder_one_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = FluxTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="text_encoder"
+            )
 
         lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
 
@@ -2069,7 +2086,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
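
Note (not part of the commit itself): the new DistributedType.DEEPSPEED branches only take effect when the script is launched through a DeepSpeed-enabled accelerate setup. A minimal, self-contained sketch of the guard pattern the diff relies on, using a hypothetical train_batch_size in place of args.train_batch_size:

from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState

# Picks up the DeepSpeed plugin if it was enabled via `accelerate config` / launch flags.
accelerator = Accelerator()

train_batch_size = 4  # hypothetical stand-in for args.train_batch_size

if accelerator.distributed_type == DistributedType.DEEPSPEED:
    # Copy the per-device batch size into the DeepSpeed plugin's config so
    # "train_micro_batch_size_per_gpu" is not left unresolved when DeepSpeed initializes.
    AcceleratorState().deepspeed_plugin.deepspeed_config[
        "train_micro_batch_size_per_gpu"
    ] = train_batch_size

On non-DeepSpeed setups the branch is a no-op, so single-GPU and DDP runs are unaffected.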