Skip to content

Commit 3178c4e

Browse files
committed
add backward compatibility
1 parent 42c1451 commit 3178c4e

File tree

7 files changed

+97
-15
lines changed

7 files changed

+97
-15
lines changed

scripts/convert_wan_to_diffusers.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
2626
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
2727
"time_projection.1": "condition_embedder.time_proj",
28-
"head.modulation": "norm_out.linear.weight",
28+
"head.modulation": "scale_shift_table",
2929
"head.head": "proj_out",
3030
"modulation": "scale_shift_table",
3131
"ffn.0": "ffn.net.0.proj",
@@ -67,7 +67,7 @@
6767
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
6868
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
6969
"time_projection.1": "condition_embedder.time_proj",
70-
"head.modulation": "norm_out.linear.weight",
70+
"head.modulation": "scale_shift_table",
7171
"head.head": "proj_out",
7272
"modulation": "scale_shift_table",
7373
"ffn.0": "ffn.net.0.proj",
@@ -105,12 +105,8 @@
105105
"after_proj": "proj_out",
106106
}
107107

108-
TRANSFORMER_SPECIAL_KEYS_REMAP = {
109-
"norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0]))
110-
}
111-
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
112-
"norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0]))
113-
}
108+
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
109+
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
114110

115111

116112
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -312,10 +308,6 @@ def convert_transformer(model_type: str):
312308
continue
313309
handler_fn_inplace(key, original_state_dict)
314310

315-
for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
316-
if special_key not in original_state_dict:
317-
handler_fn_inplace(special_key, original_state_dict)
318-
319311
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
320312
return transformer
321313

src/diffusers/models/transformers/latte_transformer_3d.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,15 @@ def __init__(
171171

172172
self.gradient_checkpointing = False
173173

174+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
175+
if "scale_shift_table" in state_dict:
176+
scale_shift_table = state_dict.pop("scale_shift_table")
177+
state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1]
178+
state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0]
179+
return super()._load_from_state_dict(
180+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
181+
)
182+
174183
def forward(
175184
self,
176185
hidden_states: torch.Tensor,

src/diffusers/models/transformers/pixart_transformer_2d.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,18 @@ def __init__(
185185
)
186186
self.caption_projection = None
187187
if self.config.caption_channels is not None:
188-
self.caption_projection = PixArtAlphaTextProjection(
189-
in_features=self.config.caption_channels, hidden_size=self.inner_dim
190-
)
188+
self.caption_projection = PixArtAlphaTextProjection(
189+
in_features=self.config.caption_channels, hidden_size=self.inner_dim
190+
)
191+
192+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
193+
if "scale_shift_table" in state_dict:
194+
scale_shift_table = state_dict.pop("scale_shift_table")
195+
state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1]
196+
state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0]
197+
return super()._load_from_state_dict(
198+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
199+
)
191200

192201
@property
193202
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors

src/diffusers/models/transformers/transformer_allegro.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,15 @@ def __init__(
310310

311311
self.gradient_checkpointing = False
312312

313+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
314+
if "scale_shift_table" in state_dict:
315+
scale_shift_table = state_dict.pop("scale_shift_table")
316+
state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1]
317+
state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0]
318+
return super()._load_from_state_dict(
319+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
320+
)
321+
313322
def forward(
314323
self,
315324
hidden_states: torch.Tensor,

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,27 @@ def __init__(
400400

401401
self.gradient_checkpointing = False
402402

403+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
404+
key = "scale_shift_table"
405+
if prefix + key in state_dict:
406+
scale_shift_table = state_dict.pop(prefix + key)
407+
inner_dim = scale_shift_table.shape[-1]
408+
409+
weight = torch.eye(inner_dim).repeat(2, 1)
410+
bias = scale_shift_table.reshape(2, inner_dim).flatten()
411+
412+
state_dict[prefix + "norm_out.linear.weight"] = weight
413+
state_dict[prefix + "norm_out.linear.bias"] = bias
414+
415+
if prefix + "norm_out.weight" in state_dict:
416+
state_dict.pop(prefix + "norm_out.weight")
417+
if prefix + "norm_out.bias" in state_dict:
418+
state_dict.pop(prefix + "norm_out.bias")
419+
420+
return super(LTXVideoTransformer3DModel, self)._load_from_state_dict(
421+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
422+
)
423+
403424
def forward(
404425
self,
405426
hidden_states: torch.Tensor,

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,27 @@ def __init__(
439439

440440
self.gradient_checkpointing = False
441441

442+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
443+
key = "scale_shift_table"
444+
if prefix + key in state_dict:
445+
scale_shift_table = state_dict.pop(prefix + key)
446+
inner_dim = scale_shift_table.shape[-1]
447+
448+
weight = torch.eye(inner_dim).repeat(2, 1)
449+
bias = scale_shift_table.reshape(2, inner_dim).flatten()
450+
451+
state_dict[prefix + "norm_out.linear.weight"] = weight
452+
state_dict[prefix + "norm_out.linear.bias"] = bias
453+
454+
if prefix + "norm_out.weight" in state_dict:
455+
state_dict.pop(prefix + "norm_out.weight")
456+
if prefix + "norm_out.bias" in state_dict:
457+
state_dict.pop(prefix + "norm_out.bias")
458+
459+
return super(WanTransformer3DModel, self)._load_from_state_dict(
460+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
461+
)
462+
442463
def forward(
443464
self,
444465
hidden_states: torch.Tensor,

src/diffusers/models/transformers/transformer_wan_vace.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,27 @@ def __init__(
270270

271271
self.gradient_checkpointing = False
272272

273+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
274+
key = "scale_shift_table"
275+
if prefix + key in state_dict:
276+
scale_shift_table = state_dict.pop(prefix + key)
277+
inner_dim = scale_shift_table.shape[-1]
278+
279+
weight = torch.eye(inner_dim).repeat(2, 1)
280+
bias = scale_shift_table.reshape(2, inner_dim).flatten()
281+
282+
state_dict[prefix + "norm_out.linear.weight"] = weight
283+
state_dict[prefix + "norm_out.linear.bias"] = bias
284+
285+
if prefix + "norm_out.weight" in state_dict:
286+
state_dict.pop(prefix + "norm_out.weight")
287+
if prefix + "norm_out.bias" in state_dict:
288+
state_dict.pop(prefix + "norm_out.bias")
289+
290+
return super(WanVACETransformer3DModel, self)._load_from_state_dict(
291+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
292+
)
293+
273294
def forward(
274295
self,
275296
hidden_states: torch.Tensor,

0 commit comments

Comments (0)