 import torch
 import torch.utils.checkpoint
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
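Note: DistributedType is imported so the hooks further down can detect when Accelerate is running on the DeepSpeed backend. A minimal sketch of that check, using only the public Accelerate API (the `accelerator` name matches the script):

from accelerate import Accelerator, DistributedType

accelerator = Accelerator()
# Under DeepSpeed, Accelerate wraps each model in a DeepSpeed engine and can
# shard optimizer/parameter state across ranks, so the save/load hooks below
# take a different path when this flag is set.
running_on_deepspeed = accelerator.distributed_type == DistributedType.DEEPSPEED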
@@ -1292,11 +1292,17 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
+                    if args.upcast_before_saving:
+                        model = model.to(torch.float32)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                elif args.train_text_encoder and isinstance(
+                    unwrap_model(model), type(unwrap_model(text_encoder_one))
+                ):  # or text_encoder_two
                     # both text encoders are of the same class, so we check hidden size to distinguish between the two
-                    hidden_size = unwrap_model(model).config.hidden_size
+                    model = unwrap_model(model)
+                    hidden_size = model.config.hidden_size
                     if hidden_size == 768:
                         text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     elif hidden_size == 1280:
@@ -1305,7 +1311,8 @@ def save_model_hook(models, weights, output_dir):
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
             StableDiffusion3Pipeline.save_lora_weights(
                 output_dir,
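Note: the save-hook changes above do three things: they compare unwrapped instances (DDP or DeepSpeed wrappers would otherwise never match the underlying class), they optionally upcast the transformer to float32 before extracting its LoRA state dict, and they guard `weights.pop()` because the list can arrive empty under DeepSpeed. A rough sketch of that unwrap-then-dispatch pattern; `collect_transformer_lora_layers` is a hypothetical helper, not part of the script:

import torch
from peft.utils import get_peft_model_state_dict

def collect_transformer_lora_layers(models, weights, transformer, unwrap_model, upcast=False):
    # Hypothetical helper illustrating the pattern used by save_model_hook above.
    lora_layers = None
    for model in models:
        # Compare unwrapped instances: accelerate may hand us DDP- or
        # DeepSpeed-wrapped modules whose own class would never match.
        if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
            model = unwrap_model(model)
            if upcast:
                model = model.to(torch.float32)  # store LoRA weights in fp32
            lora_layers = get_peft_model_state_dict(model)
        # Pop only when accelerate actually passed weights; under DeepSpeed the
        # list can be empty because DeepSpeed manages its own checkpointing.
        if weights:
            weights.pop()
    return lora_layers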
@@ -1319,17 +1326,31 @@ def load_model_hook(models, input_dir):
         text_encoder_one_ = None
         text_encoder_two_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                text_encoder_two_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+                    text_encoder_two_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = SD3Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            if args.train_text_encoder:
+                text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder"
+                )
+                text_encoder_two_ = text_encoder_cls_two.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+                )
 
         lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
 
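Note: under DeepSpeed the `models` list handed to the load hook does not contain usable unwrapped modules, so the hook rebuilds the base transformer (and, with `--train_text_encoder`, both text encoders) from the pretrained checkpoint and re-attaches the LoRA adapter before the saved `lora_state_dict` is loaded into it. A sketch of that rebuild under stated assumptions; the model id is a placeholder and the `LoraConfig` arguments mirror typical SD3 LoRA settings rather than the script's exact values:

from diffusers import SD3Transformer2DModel
from peft import LoraConfig

# Placeholder values; the script takes these from its CLI arguments.
pretrained_model = "stabilityai/stable-diffusion-3-medium-diffusers"
transformer_lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

# Recreate the base model and attach an empty adapter so the parameter names
# match the LoRA state dict restored from `input_dir` afterwards.
transformer_ = SD3Transformer2DModel.from_pretrained(pretrained_model, subfolder="transformer")
transformer_.add_adapter(transformer_lora_config)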
@@ -1829,7 +1850,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
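Note: the checkpointing guard is widened because DeepSpeed's state saving must run on every rank: ZeRO shards optimizer (and optionally parameter) state across processes, so limiting the branch to the main process would leave the checkpoint incomplete or stall the collective. A condensed sketch of the resulting condition, reusing the variable names from the script:

import os

should_save = (
    accelerator.is_main_process
    or accelerator.distributed_type == DistributedType.DEEPSPEED
)
if should_save and global_step % args.checkpointing_steps == 0:
    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    # Under DeepSpeed every rank must call save_state so that each shard of the
    # optimizer/model state is written out.
    accelerator.save_state(save_path)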