Skip to content

Commit ca4b38c

Browse files
committed
update
1 parent f09f51c commit ca4b38c

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,11 @@
8686
},
8787
"LTXTransformer3DModel": {
8888
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
89+
"default_subfolder": "transformer",
8990
},
9091
"AutoencoderKLLTX": {
9192
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
93+
"default_subfolder": "vae",
9294
},
9395
}
9496

@@ -278,6 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
278280
subfolder=subfolder,
279281
local_files_only=local_files_only,
280282
token=token,
283+
revision=revision,
281284
)
282285
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
283286

src/diffusers/loaders/single_file_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,10 @@
9393
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
9494
],
9595
"ltx-video": [
96-
("patchify_proj.weight", "transformer_blocks.27.scale_shift_table"),
96+
(
97+
"model.diffusion_model.patchify_proj.weight",
98+
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
99+
),
97100
],
98101
}
99102

@@ -2211,14 +2214,20 @@ def swap_scale_shift(weight):
22112214
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
22122215
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
22132216

2217+
def remove_keys_(key: str, state_dict):
2218+
state_dict.pop(key)
2219+
22142220
TRANSFORMER_KEYS_RENAME_DICT = {
2221+
"model.diffusion_model.": "",
22152222
"patchify_proj": "proj_in",
22162223
"adaln_single": "time_embed",
22172224
"q_norm": "norm_q",
22182225
"k_norm": "norm_k",
22192226
}
22202227

2221-
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
2228+
TRANSFORMER_SPECIAL_KEYS_REMAP = {
2229+
"vae": remove_keys_,
2230+
}
22222231

22232232
for key in list(converted_state_dict.keys()):
22242233
new_key = key
@@ -2242,6 +2251,8 @@ def remove_keys_(key: str, state_dict):
22422251
state_dict.pop(key)
22432252

22442253
VAE_KEYS_RENAME_DICT = {
2254+
# common
2255+
"vae.": "",
22452256
# decoder
22462257
"up_blocks.0": "mid_block",
22472258
"up_blocks.1": "up_blocks.0",
@@ -2276,6 +2287,7 @@ def remove_keys_(key: str, state_dict):
22762287
"per_channel_statistics.channel": remove_keys_,
22772288
"per_channel_statistics.mean-of-means": remove_keys_,
22782289
"per_channel_statistics.mean-of-stds": remove_keys_,
2290+
"model.diffusion_model": remove_keys_,
22792291
}
22802292

22812293
for key in list(converted_state_dict.keys()):

0 commit comments

Comments
 (0)