@@ -92,6 +92,8 @@
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
     ],
+    "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
+    "autoencoder-dc-sana": "encoder.project_in.conv.bias",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -138,6 +140,10 @@
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+    "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
+    "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
 }
 
 # Use to configure model sample size when original config is provided
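
For readers tracing the flow: the new `CHECKPOINT_KEY_NAMES` entries act as fingerprints that identify a DC-AE single-file checkpoint, while `DIFFUSERS_DEFAULT_PIPELINE_PATHS` supplies the Hub repo a config can be fetched from once the type is known. A minimal sketch of how the two tables cooperate, assuming a flat state dict and the `infer_diffusers_model_type` helper defined later in this file (the local path is a placeholder):

```python
from safetensors.torch import load_file

# Sketch only: how the lookup tables above are consumed.
# The filename is a placeholder for any DC-AE checkpoint in the original format.
checkpoint = load_file("dc-ae-f32c32-mix-1.0.safetensors")
model_type = infer_diffusers_model_type(checkpoint)  # e.g. "autoencoder-dc-f32c32"
repo_id = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]["pretrained_model_name_or_path"]
```
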
@@ -564,6 +570,23 @@ def infer_diffusers_model_type(checkpoint):
             model_type = "flux-dev"
         else:
             model_type = "flux-schnell"
+
+    elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
+        encoder_key = "encoder.project_in.conv.conv.bias"
+        decoder_key = "decoder.project_in.main.conv.weight"
+
+        if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
+            model_type = "autoencoder-dc-f32c32-sana"
+
+        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
+            model_type = "autoencoder-dc-f32c32"
+
+        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
+            model_type = "autoencoder-dc-f64c128"
+
+        else:
+            model_type = "autoencoder-dc-f128c512"
+
     else:
         model_type = "v1"
 
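
The dispatch above never reads a config file; it distinguishes the f32c32 / f64c128 / f128c512 variants purely from tensor shapes: the width of the encoder's input-projection bias and the input-channel dimension of the decoder's input-projection weight. A toy illustration with hypothetical shapes (real checkpoints carry many more tensors):

```python
import torch

# Hypothetical minimal state dict holding only the keys the detection reads.
checkpoint = {
    # Presence of this key routes into the DC-AE branch above.
    "decoder.stages.1.op_list.0.main.conv.conv.bias": torch.zeros(512),
    # An encoder bias of width 64 ...
    "encoder.project_in.conv.conv.bias": torch.zeros(64),
    # ... plus a decoder weight with 32 input channels selects "autoencoder-dc-f32c32".
    "decoder.project_in.main.conv.weight": torch.zeros(512, 32, 3, 3),
}
assert checkpoint["decoder.project_in.main.conv.weight"].shape[1] == 32
```
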
@@ -2198,3 +2221,75 @@ def swap_scale_shift(weight):
     )
 
     return converted_state_dict
+
+
+def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remap_qkv_(key: str, state_dict):
+        qkv = state_dict.pop(key)
+        q, k, v = torch.chunk(qkv, 3, dim=0)
+        parent_module, _, _ = key.rpartition(".qkv.conv.weight")
+        state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
+        state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
+        state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
+
+    def remap_proj_conv_(key: str, state_dict):
+        parent_module, _, _ = key.rpartition(".proj.conv.weight")
+        state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
+
+    AE_KEYS_RENAME_DICT = {
+        # common
+        "main.": "",
+        "op_list.": "",
+        "context_module": "attn",
+        "local_module": "conv_out",
+        # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
+        # If there were more scales, there would be more layers, so a loop would be better to handle this
+        "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
+        "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
+        "depth_conv.conv": "conv_depth",
+        "inverted_conv.conv": "conv_inverted",
+        "point_conv.conv": "conv_point",
+        "point_conv.norm": "norm",
+        "conv.conv.": "conv.",
+        "conv1.conv": "conv1",
+        "conv2.conv": "conv2",
+        "conv2.norm": "norm",
+        "proj.norm": "norm_out",
+        # encoder
+        "encoder.project_in.conv": "encoder.conv_in",
+        "encoder.project_out.0.conv": "encoder.conv_out",
+        "encoder.stages": "encoder.down_blocks",
+        # decoder
+        "decoder.project_in.conv": "decoder.conv_in",
+        "decoder.project_out.0": "decoder.norm_out",
+        "decoder.project_out.2.conv": "decoder.conv_out",
+        "decoder.stages": "decoder.up_blocks",
+    }
+
+    AE_F32C32_F64C128_F128C512_KEYS = {
+        "encoder.project_in.conv": "encoder.conv_in.conv",
+        "decoder.project_out.2.conv": "decoder.conv_out.conv",
+    }
+
+    AE_SPECIAL_KEYS_REMAP = {
+        "qkv.conv.weight": remap_qkv_,
+        "proj.conv.weight": remap_proj_conv_,
+    }
+    if "encoder.project_in.conv.bias" not in converted_state_dict:
+        AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
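
A hedged usage sketch for the converter above; the file path is a placeholder, and any checkpoint in the original mit-han-lab DC-AE key layout should convert the same way:

```python
from safetensors.torch import load_file

# Placeholder path: point this at a real DC-AE checkpoint in the original format.
original = load_file("dc-ae-f32c32-mix-1.0.safetensors")
converted = convert_autoencoder_dc_checkpoint_to_diffusers(original)

# Keys now follow diffusers naming, e.g. "encoder.conv_in.*" and "decoder.up_blocks.*",
# with fused qkv convolutions split into separate to_q / to_k / to_v weights.
print(sorted(k for k in converted if "to_q" in k)[:3])
```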