Commit 5eef7f1

Merge branch 'main' into sd3.5_IPAdapter
2 parents: 98f4521 + d8825e7

File tree: 5 files changed (+152 additions, -45 deletions)

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -32,6 +32,7 @@
     convert_ldm_vae_checkpoint,
     convert_ltx_transformer_checkpoint_to_diffusers,
     convert_ltx_vae_checkpoint_to_diffusers,
+    convert_mochi_transformer_checkpoint_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
@@ -96,6 +97,10 @@
         "default_subfolder": "vae",
     },
     "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
+    "MochiTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
```
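With `MochiTransformer3DModel` registered in `SINGLE_FILE_LOADABLE_CLASSES`, `from_single_file` can route a raw Mochi checkpoint through the new mapping function. A minimal sketch of the resulting workflow, assuming a locally downloaded checkpoint (the filename is a placeholder, not a file referenced by this commit):

```python
import torch
from diffusers import MochiTransformer3DModel

# Placeholder path: any original or ComfyUI-style Mochi transformer checkpoint.
ckpt_path = "mochi_transformer.safetensors"

# from_single_file() infers the model type from the checkpoint keys and applies
# convert_mochi_transformer_checkpoint_to_diffusers before building the model.
transformer = MochiTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
```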

src/diffusers/loaders/single_file_utils.py

Lines changed: 109 additions & 0 deletions
```diff
@@ -106,6 +106,7 @@
     ],
     "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
+    "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -159,6 +160,7 @@
     "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
+    "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
 }
 
 # Use to configure model sample size when original config is provided
```
```diff
@@ -618,6 +620,9 @@ def infer_diffusers_model_type(checkpoint):
         else:
             model_type = "autoencoder-dc-f128c512"
 
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
+        model_type = "mochi-1-preview"
+
     else:
         model_type = "v1"
```
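Detection keys off the fused visual QKV weight of the first transformer block, checked both with and without the `model.diffusion_model.` prefix that ComfyUI exports carry. A standalone sketch of the same rule:

```python
# Sketch of the detection rule added above: a state dict is treated as Mochi if
# block 0's fused visual QKV projection appears under either naming scheme.
MOCHI_KEYS = ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"]

def looks_like_mochi(state_dict: dict) -> bool:
    return any(key in state_dict for key in MOCHI_KEYS)
```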

```diff
@@ -1758,6 +1763,12 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def swap_proj_gate(weight):
+    proj, gate = weight.chunk(2, dim=0)
+    new_weight = torch.cat([gate, proj], dim=0)
+    return new_weight
+
+
 def get_attn2_layers(state_dict):
     attn2_layers = []
     for key in state_dict.keys():
```
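`swap_proj_gate` exists because the original Mochi checkpoints store the two halves of the fused feed-forward `w1` weight in the opposite order from what the diffusers feed-forward expects, so the converter swaps them along dim 0. A toy check of the reordering (shapes here are illustrative only):

```python
import torch

# Fuse two recognizable halves, then verify the helper's reordering.
# Real Mochi w1 weights have shape (2 * inner_dim, dim); these are tiny.
proj = torch.zeros(4, 8)
gate = torch.ones(4, 8)
fused = torch.cat([proj, gate], dim=0)

p, g = fused.chunk(2, dim=0)          # same steps as swap_proj_gate
swapped = torch.cat([g, p], dim=0)

assert torch.equal(swapped[:4], gate) and torch.equal(swapped[4:], proj)
```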
```diff
@@ -2414,3 +2425,101 @@ def remap_proj_conv_(key: str, state_dict):
     handler_fn_inplace(key, converted_state_dict)
 
     return converted_state_dict
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    new_state_dict = {}
+
+    # Comfy checkpoints add this prefix
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    # Convert patch_embed
+    new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Convert time_embed
+    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+    new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+    new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+    new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+    new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+    new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+    new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+    new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+
+    # Convert transformer blocks
+    num_layers = 48
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        old_prefix = f"blocks.{i}."
+
+        # norm1
+        new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+        new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
+            new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+        else:
+            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+                old_prefix + "mod_y.weight"
+            )
+            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+
+        # Visual attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+
+        # Context attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+            old_prefix + "attn.q_norm_y.weight"
+        )
+        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+            old_prefix + "attn.k_norm_y.weight"
+        )
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+                old_prefix + "attn.proj_y.weight"
+            )
+            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
+
+        # MLP
+        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+            checkpoint.pop(old_prefix + "mlp_x.w1.weight")
+        )
+        new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+                checkpoint.pop(old_prefix + "mlp_y.w1.weight")
+            )
+            new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
+
+    # Output layers
+    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+    new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+    new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+
+    return new_state_dict
```
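The converter is a pure key-remapping pass over the original state dict, so it can also be exercised directly. A hedged sketch, assuming a locally saved original-format checkpoint (the filename is hypothetical) and that the default `MochiTransformer3DModel` config matches the released 48-layer transformer:

```python
from safetensors.torch import load_file
from diffusers import MochiTransformer3DModel
from diffusers.loaders.single_file_utils import convert_mochi_transformer_checkpoint_to_diffusers

# Hypothetical local path to an original-format Mochi transformer checkpoint.
checkpoint = load_file("mochi_dit.safetensors")

# Remap original keys (x_embedder.*, blocks.*, final_layer.*) to diffusers naming.
converted = convert_mochi_transformer_checkpoint_to_diffusers(checkpoint)

# Assumes the default config matches the released model; pass explicit config
# values to the constructor otherwise.
model = MochiTransformer3DModel()
model.load_state_dict(converted)
```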

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -20,6 +20,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
@@ -304,7 +305,7 @@ def forward(
 
 
 @maybe_allow_in_graph
-class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
 
@@ -334,6 +335,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["MochiTransformerBlock"]
 
     @register_to_config
     def __init__(
```
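Two loading-related changes land here: `FromOriginalModelMixin` gives the class the `from_single_file` entry point used above, and `_no_split_modules` tells accelerate-backed placement that a `MochiTransformerBlock` must not be sharded across devices. A sketch of multi-GPU placement, assuming diffusers' usual `device_map` support applies to this model:

```python
import torch
from diffusers import MochiTransformer3DModel

# With _no_split_modules = ["MochiTransformerBlock"], automatic placement keeps
# each block whole on a single device. device_map="auto" assumes the standard
# accelerate-backed loading path.
transformer = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```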

src/diffusers/utils/hub_utils.py

Lines changed: 29 additions & 38 deletions
```diff
@@ -455,48 +455,39 @@ def _get_checkpoint_shard_files(
     allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
 
     ignore_patterns = ["*.json", "*.md"]
-    if not local_files_only:
-        # `model_info` call must guarded with the above condition.
-        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
-        for shard_file in original_shard_filenames:
-            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
-            if not shard_file_present:
-                raise EnvironmentError(
-                    f"{shards_path} does not appear to have a file named {shard_file} which is "
-                    "required according to the checkpoint index."
-                )
-
-        try:
-            # Load from URL
-            cached_folder = snapshot_download(
-                pretrained_model_name_or_path,
-                cache_dir=cache_dir,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                token=token,
-                revision=revision,
-                allow_patterns=allow_patterns,
-                ignore_patterns=ignore_patterns,
-                user_agent=user_agent,
-            )
-            if subfolder is not None:
-                cached_folder = os.path.join(cached_folder, subfolder)
-
-        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
-        # we don't have to catch them here. We have also dealt with EntryNotFoundError.
-        except HTTPError as e:
+    # `model_info` call must guarded with the above condition.
+    model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+    for shard_file in original_shard_filenames:
+        shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+        if not shard_file_present:
             raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
-                " again after checking your internet connection."
-            ) from e
+                f"{shards_path} does not appear to have a file named {shard_file} which is "
+                "required according to the checkpoint index."
+            )
 
-    # If `local_files_only=True`, `cached_folder` may not contain all the shard files.
-    elif local_files_only:
-        _check_if_shards_exist_locally(
-            local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
+    try:
+        # Load from URL
+        cached_folder = snapshot_download(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            allow_patterns=allow_patterns,
+            ignore_patterns=ignore_patterns,
+            user_agent=user_agent,
        )
        if subfolder is not None:
-            cached_folder = os.path.join(cache_dir, subfolder)
+            cached_folder = os.path.join(cached_folder, subfolder)
+
+    # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
+    # we don't have to catch them here. We have also dealt with EntryNotFoundError.
+    except HTTPError as e:
+        raise EnvironmentError(
+            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
+            " again after checking your internet connection."
+        ) from e
 
     return cached_folder, sharded_metadata
```
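This refactor drops the `local_files_only` branching (including the `_check_if_shards_exist_locally` path), always validating shard names against `model_info` and delegating caching to `snapshot_download`; it also fixes the subfolder join to use `cached_folder` rather than `cache_dir`. The narrowing pattern it relies on can be sketched in isolation (the repo and patterns below are illustrative, not taken from this commit):

```python
from huggingface_hub import snapshot_download

# Fetch only the files matching the allow patterns, skipping index/readme
# files, mirroring the allow_patterns/ignore_patterns usage above.
folder = snapshot_download(
    "genmo/mochi-1-preview",
    allow_patterns=["transformer/*.safetensors"],
    ignore_patterns=["*.json", "*.md"],
)
print(folder)  # local cache directory containing just the matched files
```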

tests/lora/utils.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -89,12 +89,12 @@ class PeftLoraLoaderMixinTests:
 
     has_two_text_encoders = False
     has_three_text_encoders = False
-    text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, None
-    text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, None
-    text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, None
-    tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, None
-    tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, None
-    tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, None
+    text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, ""
+    text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, ""
+    text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, ""
+    tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
+    tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
+    tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
 
     unet_kwargs = None
     transformer_cls = None
```
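Defaulting the subfolders to `""` instead of `None` presumably lets path-building helpers (such as the `os.path.join(subfolder, p)` call in `_get_checkpoint_shard_files` above) treat the default uniformly, since joining with an empty string is a no-op while joining with `None` raises:

```python
import os

# "" is a neutral subfolder: the join is a no-op.
assert os.path.join("", "model-00001-of-00002.safetensors") == "model-00001-of-00002.safetensors"

# None is not: os.path.join(None, ...) raises TypeError.
try:
    os.path.join(None, "model-00001-of-00002.safetensors")
except TypeError as exc:
    print(f"None subfolder fails: {exc}")
```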
