|
25 | 25 | "text_embedding.0": "condition_embedder.text_embedder.linear_1", |
26 | 26 | "text_embedding.2": "condition_embedder.text_embedder.linear_2", |
27 | 27 | "time_projection.1": "condition_embedder.time_proj", |
28 | | - "head.modulation": "norm_out.linear.weight", |
| 28 | + "head.modulation": "scale_shift_table", |
29 | 29 | "head.head": "proj_out", |
30 | 30 | "modulation": "scale_shift_table", |
31 | 31 | "ffn.0": "ffn.net.0.proj", |
|
67 | 67 | "text_embedding.0": "condition_embedder.text_embedder.linear_1", |
68 | 68 | "text_embedding.2": "condition_embedder.text_embedder.linear_2", |
69 | 69 | "time_projection.1": "condition_embedder.time_proj", |
70 | | - "head.modulation": "norm_out.linear.weight", |
| 70 | + "head.modulation": "scale_shift_table", |
71 | 71 | "head.head": "proj_out", |
72 | 72 | "modulation": "scale_shift_table", |
73 | 73 | "ffn.0": "ffn.net.0.proj", |
|
105 | 105 | "after_proj": "proj_out", |
106 | 106 | } |
107 | 107 |
|
108 | | -TRANSFORMER_SPECIAL_KEYS_REMAP = { |
109 | | - "norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0])) |
110 | | -} |
111 | | -VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = { |
112 | | - "norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0])) |
113 | | -} |
| 108 | +TRANSFORMER_SPECIAL_KEYS_REMAP = {} |
| 109 | +VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {} |
114 | 110 |
|
115 | 111 |
|
116 | 112 | def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: |
@@ -312,10 +308,6 @@ def convert_transformer(model_type: str): |
312 | 308 | continue |
313 | 309 | handler_fn_inplace(key, original_state_dict) |
314 | 310 |
|
315 | | - for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items(): |
316 | | - if special_key not in original_state_dict: |
317 | | - handler_fn_inplace(special_key, original_state_dict) |
318 | | - |
319 | 311 | transformer.load_state_dict(original_state_dict, strict=True, assign=True) |
320 | 312 | return transformer |
321 | 313 |
|
|
0 commit comments