Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
6e616a9
first add a script for DC-AE;
lawrence-cj Oct 18, 2024
d2e187a
Merge remote-tracking branch 'upstream/main' into DC-AE
chenjy2003 Oct 23, 2024
90e8939
DC-AE init
chenjy2003 Oct 23, 2024
825c975
replace triton with custom implementation
chenjy2003 Oct 23, 2024
3a44fa4
1. rename file and remove un-used codes;
lawrence-cj Oct 23, 2024
55b2615
no longer rely on omegaconf and dataclass
chenjy2003 Oct 25, 2024
6fb7fdb
merge
chenjy2003 Oct 25, 2024
c323e76
Merge remote-tracking branch 'upstream/main' into DC-AE
chenjy2003 Oct 25, 2024
da7caa5
replace custom activation with diffuers activation
chenjy2003 Oct 25, 2024
fb6d92a
remove dc_ae attention in attention_processor.py
chenjy2003 Oct 25, 2024
5e63a1a
iinherit from ModelMixin
chenjy2003 Oct 25, 2024
72cce2b
inherit from ConfigMixin
chenjy2003 Oct 25, 2024
8f9b4e4
dc-ae reduce to one file
chenjy2003 Oct 31, 2024
b7f68f9
Merge remote-tracking branch 'upstream/main' into DC-AE
chenjy2003 Oct 31, 2024
6d96b95
Merge branch 'huggingface:main' into DC-AE
lawrence-cj Nov 4, 2024
3c3cc51
Merge remote-tracking branch 'refs/remotes/origin/main' into DC-AE
lawrence-cj Nov 6, 2024
1448681
update downsample and upsample
chenjy2003 Nov 9, 2024
bf40fe8
merge
chenjy2003 Nov 9, 2024
dd7718a
clean code
chenjy2003 Nov 9, 2024
19986a5
support DecoderOutput
chenjy2003 Nov 9, 2024
3481e23
Merge branch 'main' into DC-AE
lawrence-cj Nov 9, 2024
0e818df
Merge branch 'main' into DC-AE
lawrence-cj Nov 13, 2024
c6eb233
remove get_same_padding and val2tuple
chenjy2003 Nov 14, 2024
59de0a3
remove autocast and some assert
chenjy2003 Nov 14, 2024
ea604a4
update ResBlock
chenjy2003 Nov 14, 2024
80dce02
remove contents within super().__init__
chenjy2003 Nov 14, 2024
1752afd
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 16, 2024
883bcf4
remove opsequential
chenjy2003 Nov 16, 2024
25ae389
Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE
chenjy2003 Nov 16, 2024
96e844b
update other blocks to support the removal of build_norm
chenjy2003 Nov 16, 2024
59b6e25
Merge branch 'main' into DC-AE
sayakpaul Nov 16, 2024
7ce9ff2
remove build encoder/decoder project in/out
chenjy2003 Nov 16, 2024
30d6308
Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE
chenjy2003 Nov 16, 2024
cab56b1
remove inheritance of RMSNorm2d from LayerNorm
chenjy2003 Nov 16, 2024
b42bb54
remove reset_parameters for RMSNorm2d
chenjy2003 Nov 20, 2024
2e04a99
remove device and dtype in RMSNorm2d __init__
chenjy2003 Nov 20, 2024
b4f75f2
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 21, 2024
c82f828
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 21, 2024
22ea5fd
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 21, 2024
4f5cbb4
remove op_list & build_block
chenjy2003 Nov 26, 2024
2f6bbad
remove build_stage_main
chenjy2003 Nov 26, 2024
4495783
Merge branch 'main' into DC-AE
lawrence-cj Nov 26, 2024
4d3c026
change file name to autoencoder_dc
chenjy2003 Nov 28, 2024
e007057
Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE
chenjy2003 Nov 28, 2024
d3d9c84
move LiteMLA to attention.py
chenjy2003 Nov 28, 2024
5ed50e9
update
a-r-r-o-w Nov 28, 2024
c1c02a2
quick push before dgx disappears again
a-r-r-o-w Nov 28, 2024
1f8a3b3
update
a-r-r-o-w Nov 28, 2024
7b9d7e5
make style
a-r-r-o-w Nov 28, 2024
bf6c211
update
a-r-r-o-w Nov 28, 2024
a2ec5f8
update
a-r-r-o-w Nov 28, 2024
f5876c5
fix
a-r-r-o-w Nov 28, 2024
44034a6
refactor
a-r-r-o-w Nov 29, 2024
6379241
refactor
a-r-r-o-w Nov 29, 2024
77571a8
refactor
a-r-r-o-w Nov 29, 2024
c4d0867
update
a-r-r-o-w Nov 30, 2024
0bdb7ef
possibly change to nn.Linear
a-r-r-o-w Nov 30, 2024
54e933b
refactor
a-r-r-o-w Nov 30, 2024
babc9f5
Merge branch 'main' into aryan-dcae
a-r-r-o-w Nov 30, 2024
3d5faaf
make fix-copies
a-r-r-o-w Nov 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 275 additions & 0 deletions scripts/convert_dcae_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
import argparse
from typing import Any, Dict

import torch
from safetensors.torch import load_file

from diffusers import AutoencoderDC


def remove_keys_(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)


def remap_qkv_(key: str, state_dict: Dict[str, Any]):
# qkv = state_dict.pop(key)
# q, k, v = torch.chunk(qkv, 3, dim=0)
# parent_module, _, _ = key.rpartition(".qkv.conv.weight")
# state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
# state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
# state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
state_dict[key.replace("qkv.conv", "to_qkv")] = state_dict.pop(key)


VAE_KEYS_RENAME_DICT = {
# common
"main.": "",
"op_list.": "",
"context_module": "attn",
"local_module": "conv_out",
# NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
# If there were more scales, there would be more layers, so a loop would be better to handle this
"aggreg.0.0": "to_qkv_multiscale.0.proj_in",
"aggreg.0.1": "to_qkv_multiscale.0.proj_out",
"norm.": "norm.norm.",
"depth_conv.conv": "conv_depth",
"inverted_conv.conv": "conv_inverted",
"point_conv.conv": "conv_point",
"point_conv.norm": "norm",
"conv.conv.": "conv.",
"conv1.conv": "conv1",
"conv2.conv": "conv2",
"conv2.norm": "norm",
"proj.conv": "proj_out",
"proj.norm": "norm_out",
# encoder
"encoder.project_in.conv": "encoder.conv_in",
"encoder.project_out.0.conv": "encoder.conv_out",
# decoder
"decoder.project_in.conv": "decoder.conv_in",
"decoder.project_out.0": "decoder.norm_out.norm",
"decoder.project_out.2.conv": "decoder.conv_out",
}

VAE_SPECIAL_KEYS_REMAP = {
"qkv.conv.weight": remap_qkv_,
}


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)


def convert_vae(ckpt_path: str, dtype: torch.dtype):
original_state_dict = get_state_dict(load_file(ckpt_path))
vae = AutoencoderDC(
in_channels=3,
latent_channels=32,
encoder_block_types=(
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
),
decoder_block_types=(
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
),
block_out_channels=(128, 256, 512, 512, 1024, 1024),
encoder_layers_per_block=(2, 2, 2, 3, 3, 3),
decoder_layers_per_block=(3, 3, 3, 3, 3, 3),
encoder_qkv_multiscales=((), (), (), (5,), (5,), (5,)),
decoder_qkv_multiscales=((), (), (), (5,), (5,), (5,)),
downsample_block_type="Conv",
upsample_block_type="interpolate",
decoder_norm_types="rms_norm",
decoder_act_fns="silu",
scaling_factor=0.41407,
).to(dtype=dtype)

for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)

for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)

vae.load_state_dict(original_state_dict, strict=True)
return vae


def get_vae_config(name: str):
if name in ["dc-ae-f32c32-sana-1.0"]:
config = {
"latent_channels": 32,
"encoder_block_types": ("ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"),
"decoder_block_types": ("ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"),
"block_out_channels": (128, 256, 512, 512, 1024, 1024),
"encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
"decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
"encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
"decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
"downsample_block_type": "Conv",
"upsample_block_type": "interpolate",
"scaling_factor": 0.41407,
}
elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
config = {
"latent_channels": 32,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"block_out_channels": [128, 256, 512, 512, 1024, 1024],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), ()),
"decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
}
elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
config = {
"latent_channels": 512,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
"decoder_norm_types": [
"batch_norm",
"batch_norm",
"batch_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
}
elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
config = {
"latent_channels": 128,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
"decoder_norm_types": [
"batch_norm",
"batch_norm",
"batch_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
}

return config


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
return parser.parse_args()


DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}


if __name__ == "__main__":
args = get_args()

dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]

if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path, dtype)
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel",
"AutoencoderDC",
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
Expand Down Expand Up @@ -571,6 +572,7 @@
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel,
AutoencoderDC,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
Expand Down Expand Up @@ -88,6 +89,7 @@
from .adapter import MultiAdapter, T2IAdapter
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderDC,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
Expand Down
8 changes: 7 additions & 1 deletion src/diffusers/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
from .normalization import (
AdaLayerNorm,
AdaLayerNormContinuous,
AdaLayerNormZero,
RMSNorm,
SD35AdaLayerNormZeroX,
)


logger = logging.get_logger(__name__)
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_dc import AutoencoderDC
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
Expand Down
Loading
Loading