From 46b6de96f03442e016062718a470fad114d6cf75 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?=
Date: Wed, 18 Jun 2025 11:02:40 +0800
Subject: [PATCH 1/2] [training] add ds support to lora hidream

---
 .../train_dreambooth_lora_hidream.py          | 30 ++++++++++++-------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py
index f368fb809e73..2ec582eeb6b1 100644
--- a/examples/dreambooth/train_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/train_dreambooth_lora_hidream.py
@@ -29,7 +29,7 @@
 import numpy as np
 import torch
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
@@ -1181,13 +1181,15 @@ def save_model_hook(models, weights, output_dir):
             transformer_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             HiDreamImagePipeline.save_lora_weights(
                 output_dir,
@@ -1197,13 +1199,21 @@ def load_model_hook(models, input_dir):
         transformer_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
+                    transformer_ = model
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = HiDreamImageTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path,
+                subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)

         lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir)
@@ -1655,7 +1665,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:

From d5bbe9591df202e69a428b16c54f87a37682dc60 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Wed, 18 Jun 2025 03:22:58 +0000
Subject: [PATCH 2/2] Apply style fixes

---
 examples/dreambooth/train_dreambooth_lora_hidream.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py
index 2ec582eeb6b1..a1337e8dbaa4 100644
--- a/examples/dreambooth/train_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/train_dreambooth_lora_hidream.py
@@ -1210,8 +1210,7 @@ def load_model_hook(models, input_dir):
                     raise ValueError(f"unexpected save model: {model.__class__}")
             else:
                 transformer_ = HiDreamImageTransformer2DModel.from_pretrained(
-                    args.pretrained_model_name_or_path,
-                    subfolder="transformer"
+                    args.pretrained_model_name_or_path, subfolder="transformer"
                 )
                 transformer_.add_adapter(transformer_lora_config)
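
For reviewers: the two hooks patched above are registered with Accelerate's checkpointing machinery elsewhere in this script. A minimal Python sketch of that wiring, assuming the standard diffusers training-script layout (the hook bodies are condensed to comments; `register_save_state_pre_hook` and `register_load_state_pre_hook` are the actual Accelerate APIs the scripts use):

    from accelerate import Accelerator, DistributedType

    accelerator = Accelerator()

    def save_model_hook(models, weights, output_dir):
        # Under DeepSpeed, Accelerate can hand this hook an empty `weights`
        # list (DeepSpeed snapshots model state itself), which is why the
        # patch guards the pop with `if weights:`.
        ...

    def load_model_hook(models, input_dir):
        # Under DeepSpeed, `models` is likewise not populated, so the patched
        # hook rebuilds the transformer from the pretrained checkpoint and
        # re-attaches the LoRA adapter before loading the LoRA state dict.
        ...

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

These hooks fire inside `accelerator.save_state()` / `accelerator.load_state()`. DeepSpeed checkpointing is a collective operation that every rank must participate in, not just the main process, which is what the widened `is_main_process or distributed_type == DistributedType.DEEPSPEED` gate in the training loop accounts for.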