[DC-AE] Add the official Deep Compression Autoencoder code (32x, 64x, 128x compression ratios) #9708
Changes from 94 commits
@@ -0,0 +1,59 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderDC

*The 2D Autoencoder model used in [SANA](https://huggingface.co/papers/2410.10629) and introduced in [DCAE](https://huggingface.co/papers/2410.10733) by authors Junyu Chen\*, Han Cai\*, Junsong Chen, Enze Xie, Shang Yang, Haotian Tang, Muyang Li, Yao Lu, Song Han from MIT HAN Lab.*

The following DCAE models are released and supported in Diffusers:
| diffusers format | original format |
|:----------------:|:---------------:|
| [`mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-sana-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0) |
| [`mit-han-lab/dc-ae-f32c32-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0) |
| [`mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0) |
| [`mit-han-lab/dc-ae-f64c128-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0) |
| [`mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0) |
| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0) |
| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0) |
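In these checkpoint names, `fN` denotes the spatial downsampling factor and `cM` the number of latent channels (this reading of the naming convention follows the DC-AE paper). A quick sketch of the latent-shape arithmetic this implies:

```python
# Latent-shape arithmetic for the DC-AE naming convention:
# fN = spatial downsampling factor, cM = latent channels.
def latent_shape(height: int, width: int, f: int, c: int) -> tuple:
    """Return the (channels, height, width) of the latent for an image."""
    assert height % f == 0 and width % f == 0, "image size must be divisible by f"
    return (c, height // f, width // f)

# A 1024x1024 RGB image under each released compression ratio:
print(latent_shape(1024, 1024, f=32, c=32))    # (32, 32, 32)
print(latent_shape(1024, 1024, f=64, c=128))   # (128, 16, 16)
print(latent_shape(1024, 1024, f=128, c=512))  # (512, 8, 8)
```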
The models can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
```
## Single file loading

The `AutoencoderDC` implementation supports loading checkpoints shipped in the original format by MIT HAN Lab. The following example demonstrates how to load the `f128c512` checkpoint:

```python
from diffusers import AutoencoderDC

model_name = "dc-ae-f128c512-mix-1.0"
ae = AutoencoderDC.from_single_file(
    f"https://huggingface.co/mit-han-lab/{model_name}/resolve/main/model.safetensors",
    original_config=f"https://huggingface.co/mit-han-lab/{model_name}/resolve/main/config.json",
)
```
## AutoencoderDC

[[autodoc]] AutoencoderDC
  - decode
  - all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -0,0 +1,323 @@
import argparse
from typing import Any, Dict

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from diffusers import AutoencoderDC
def remap_qkv_(key: str, state_dict: Dict[str, Any]) -> None:
    # Split a fused QKV 1x1-conv weight into separate q/k/v projection weights.
    qkv = state_dict.pop(key)
    q, k, v = torch.chunk(qkv, 3, dim=0)
    parent_module, _, _ = key.rpartition(".qkv.conv.weight")
    state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
    state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
    state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()


def remap_proj_conv_(key: str, state_dict: Dict[str, Any]) -> None:
    # Flatten the output-projection conv weight into a linear weight.
    parent_module, _, _ = key.rpartition(".proj.conv.weight")
    state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
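The `rpartition` call above recovers the parent module path from a flattened checkpoint key by splitting on the last occurrence of the known suffix. A minimal stdlib sketch with a hypothetical key:

```python
# str.rpartition splits on the LAST occurrence of the separator; the prefix
# is the parent module path and the matched suffix is dropped.
# The key below is a hypothetical example in the original checkpoint layout.
key = "encoder.stages.3.op_list.0.context_module.qkv.conv.weight"
parent_module, sep, tail = key.rpartition(".qkv.conv.weight")

print(parent_module)  # encoder.stages.3.op_list.0.context_module
# New keys are then built from the parent path:
print(f"{parent_module}.to_q.weight")
```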
AE_KEYS_RENAME_DICT = {
    # common
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
    # If there were more scales, there would be more layers, so a loop would be better to handle this
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # encoder
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # decoder
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}
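The rename dict is applied by ordered substring replacement (Python dicts preserve insertion order, so earlier entries are applied first). A minimal sketch using a subset of the mappings above and a hypothetical checkpoint key:

```python
# Subset of the rename mappings; insertion order matters because the
# replacements are applied sequentially to each key.
RENAME = {
    "main.": "",
    "op_list.": "",
    "conv1.conv": "conv1",
    "decoder.stages": "decoder.up_blocks",
}

def rename_key(key: str) -> str:
    for old, new in RENAME.items():
        key = key.replace(old, new)
    return key

# Hypothetical original-format key:
print(rename_key("decoder.stages.0.op_list.0.main.conv1.conv.weight"))
# decoder.up_blocks.0.0.conv1.weight
```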
# NOTE: These per-variant key overrides are currently identical.
AE_F32C32_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F64C128_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F128C512_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}
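`AE_SPECIAL_KEYS_REMAP` dispatches on substring match: any key containing one of the patterns is handed to an in-place handler that mutates the state dict. A stdlib sketch of the dispatch mechanics with a toy handler and hypothetical keys:

```python
# Toy in-place handler: the real handlers split fused qkv weights; this one
# only rewrites the key name to show the dispatch mechanics.
def toy_handler(key, state_dict):
    state_dict[key.replace(".qkv.conv.weight", ".to_qkv.weight")] = state_dict.pop(key)

SPECIAL = {"qkv.conv.weight": toy_handler}

state_dict = {
    "attn.qkv.conv.weight": 1,  # matches -> handled
    "conv1.weight": 2,          # no match -> untouched
}

# Iterate over a snapshot of the keys because handlers mutate the dict.
for key in list(state_dict.keys()):
    for special_key, handler in SPECIAL.items():
        if special_key not in key:
            continue
        handler(key, state_dict)

print(sorted(state_dict))  # ['attn.to_qkv.weight', 'conv1.weight']
```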
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Unwrap common checkpoint containers ("model", "module", "state_dict")
    # until the raw parameter dict is reached. Note the checks are against the
    # (possibly already unwrapped) state_dict, so nested containers also work.
    state_dict = saved_dict
    if "model" in state_dict:
        state_dict = state_dict["model"]
    if "module" in state_dict:
        state_dict = state_dict["module"]
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    return state_dict
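For example, a checkpoint saved as `{"model": {"state_dict": ...}}` unwraps to the inner parameter dict. A self-contained sketch with toy data (the loop is a compact equivalent of the sequential checks above):

```python
def get_state_dict(saved_dict):
    # Unwrap common checkpoint containers until the raw parameter dict remains.
    state_dict = saved_dict
    for wrapper in ("model", "module", "state_dict"):
        if wrapper in state_dict:
            state_dict = state_dict[wrapper]
    return state_dict

plain = {"conv_in.weight": [1.0]}
wrapped = {"model": {"state_dict": plain}}

print(get_state_dict(wrapped) is plain)  # True
print(get_state_dict(plain) is plain)    # True
```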
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def convert_ae(config_name: str, dtype: torch.dtype) -> AutoencoderDC:
    config = get_ae_config(config_name)
    hub_id = f"mit-han-lab/{config_name}"
    ckpt_path = hf_hub_download(hub_id, "model.safetensors")
    original_state_dict = get_state_dict(load_file(ckpt_path))

    ae = AutoencoderDC(**config).to(dtype=dtype)

    # First pass: apply the plain rename rules to every key.
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Second pass: run the special handlers (qkv splitting, proj flattening).
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    ae.load_state_dict(original_state_dict, strict=True)
    return ae
def get_ae_config(name: str):
    if name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": (3, 3, 3, 3, 3, 3),
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }
    elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        config = {
            "latent_channels": 32,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f32c32-in-1.0":
            config["scaling_factor"] = 0.3189
        elif name == "dc-ae-f32c32-mix-1.0":
            config["scaling_factor"] = 0.4552
    elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        config = {
            "latent_channels": 128,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f64c128-in-1.0":
            config["scaling_factor"] = 0.2889
        elif name == "dc-ae-f64c128-mix-1.0":
            config["scaling_factor"] = 0.4538
    elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        config = {
            "latent_channels": 512,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f128c512-in-1.0":
            config["scaling_factor"] = 0.4883
        elif name == "dc-ae-f128c512-mix-1.0":
            config["scaling_factor"] = 0.3620
    else:
        raise ValueError(f"Invalid config name {name!r} provided.")

    return config
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=[
            "dc-ae-f32c32-sana-1.0",
            "dc-ae-f32c32-in-1.0",
            "dc-ae-f32c32-mix-1.0",
            "dc-ae-f64c128-in-1.0",
            "dc-ae-f64c128-mix-1.0",
            "dc-ae-f128c512-in-1.0",
            "dc-ae-f128c512-mix-1.0",
        ],
        help="The DCAE checkpoint to convert",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where the converted model should be saved")
    parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16"], help="Torch dtype to save the model in.")
    return parser.parse_args()
DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}
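A non-None `variant` makes `save_pretrained` write variant-suffixed weight files (e.g. `diffusion_pytorch_model.fp16.safetensors`), while `None` keeps the default filename. A small sketch of that naming convention (the helper below is hypothetical, not a diffusers API):

```python
# Hypothetical helper mirroring diffusers' variant filename convention:
# a non-None variant is inserted before the file extension.
def weight_filename(variant=None, base="diffusion_pytorch_model", ext="safetensors"):
    return f"{base}.{ext}" if variant is None else f"{base}.{variant}.{ext}"

VARIANT_MAPPING = {"fp32": None, "fp16": "fp16", "bf16": "bf16"}

print(weight_filename(VARIANT_MAPPING["fp32"]))  # diffusion_pytorch_model.safetensors
print(weight_filename(VARIANT_MAPPING["fp16"]))  # diffusion_pytorch_model.fp16.safetensors
```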

if __name__ == "__main__":
    args = get_args()

    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    ae = convert_ae(args.config_name, dtype)
    ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)