@@ -29,8 +29,9 @@
 import numpy as np
 import torch
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
@@ -1222,6 +1223,9 @@ def main(args):
         kwargs_handlers=[kwargs],
     )

+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
@@ -1438,17 +1442,20 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_one_lora_layers_to_save = None
             modules_to_save = {}
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["transformer"] = model
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    model = unwrap_model(model)
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             FluxKontextPipeline.save_lora_weights(
                 output_dir,
@@ -1461,15 +1468,25 @@ def load_model_hook(models, input_dir):
         transformer_ = None
         text_encoder_one_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = FluxTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="text_encoder"
+            )

         lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)

@@ -2069,7 +2086,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
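A minimal sketch of the DeepSpeed path this diff adds, shown outside the training script (not part of the diff; the ZeRO stage, gradient accumulation value, and batch size are placeholder assumptions, and it presumes DeepSpeed is installed and the process is started via accelerate launch):

from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from accelerate.utils import DeepSpeedPlugin

# Build an Accelerator backed by DeepSpeed; accelerate derives the DeepSpeed
# config from the plugin fields (values here are placeholders).
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=1)
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)

if accelerator.distributed_type == DistributedType.DEEPSPEED:
    # DeepSpeed needs a concrete micro batch size before accelerator.prepare();
    # the training script fills it from args.train_batch_size, a placeholder is used here.
    AcceleratorState().deepspeed_plugin.deepspeed_config[
        "train_micro_batch_size_per_gpu"
    ] = 4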