Skip to content

Commit 139568a

Browse files
committed
Merge remote-tracking branch 'origin/hi-dream' into hi-dream
2 parents 5257b46 + dafc5fe commit 139568a

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,7 @@ def _encode_prompt_with_clip(
991991
if text_input_ids is None:
992992
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
993993

994-
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
994+
prompt_embeds = text_encoder(text_input_ids.to(device))
995995

996996
if hasattr(text_encoder, "module"):
997997
dtype = text_encoder.module.dtype
@@ -1216,7 +1216,7 @@ def main(args):
12161216
revision=args.revision,
12171217
variant=args.variant,
12181218
)
1219-
transformer = Lumina2Transformer2DModel.from_pretrained(
1219+
transformer = HiDreamImageTransformer2DModel.from_pretrained(
12201220
args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
12211221
)
12221222

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,24 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
3333
# 1. get all state_dict_keys
3434
all_keys = list(state_dict.keys())
3535
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
36+
not_sgm_patterns = ["down_blocks", "mid_block", "up_blocks"]
37+
38+
# check if state_dict contains both patterns
39+
contains_sgm_patterns = False
40+
contains_not_sgm_patterns = False
41+
for key in all_keys:
42+
if any(p in key for p in sgm_patterns):
43+
contains_sgm_patterns = True
44+
elif any(p in key for p in not_sgm_patterns):
45+
contains_not_sgm_patterns = True
46+
47+
# if state_dict contains both patterns, remove sgm
48+
# we can then return state_dict immediately
49+
if contains_sgm_patterns and contains_not_sgm_patterns:
50+
for key in all_keys:
51+
if any(p in key for p in sgm_patterns):
52+
state_dict.pop(key)
53+
return state_dict
3654

3755
# 2. check if needs remapping, if not return original dict
3856
is_in_sgm_format = False
@@ -126,7 +144,7 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
126144
)
127145
new_state_dict[new_key] = state_dict.pop(key)
128146

129-
if len(state_dict) > 0:
147+
if state_dict:
130148
raise ValueError("At this point all state dict entries have to be converted.")
131149

132150
return new_state_dict

0 commit comments

Comments
 (0)