Commit 5ed50e9

update

committed · 1 parent d3d9c84 · commit 5ed50e9

7 files changed: 837 additions & 481 deletions

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
import argparse
from typing import Any, Dict

import torch
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderDC


def remove_keys_(key: str, state_dict: Dict[str, Any]):
    state_dict.pop(key)


TOKENIZER_MAX_LENGTH = 128

TRANSFORMER_KEYS_RENAME_DICT = {}

TRANSFORMER_SPECIAL_KEYS_REMAP = {}

VAE_KEYS_RENAME_DICT = {
    # common
    "norm.": "norm.norm.",
    # encoder
    "encoder.project_in": "encoder.conv_in",
    "encoder.project_out.main.op_list.0": "encoder.conv_out",
    # decoder
    "decoder.project_in.main": "decoder.conv_in",
    "decoder.project_out.op_list.0": "decoder.norm_out.norm",
    "decoder.project_out.op_list.2": "decoder.conv_out",
}

VAE_SPECIAL_KEYS_REMAP = {}


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
    state_dict[new_key] = state_dict.pop(old_key)


# def convert_transformer(
#     ckpt_path: str,
#     dtype: torch.dtype,
# ):
#     PREFIX_KEY = ""

#     original_state_dict = get_state_dict(load_file(ckpt_path))
#     transformer = LTXTransformer3DModel().to(dtype=dtype)

#     for key in list(original_state_dict.keys()):
#         new_key = key[len(PREFIX_KEY) :]
#         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
#             new_key = new_key.replace(replace_key, rename_key)
#         update_state_dict_inplace(original_state_dict, key, new_key)

#     for key in list(original_state_dict.keys()):
#         for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
#             if special_key not in key:
#                 continue
#             handler_fn_inplace(key, original_state_dict)

#     transformer.load_state_dict(original_state_dict, strict=True)
#     return transformer


def convert_vae(ckpt_path: str, dtype: torch.dtype):
    original_state_dict = get_state_dict(load_file(ckpt_path))
    vae = AutoencoderDC(
        in_channels=3,
        latent_channels=32,
        encoder_width_list=[128, 256, 512, 512, 1024, 1024],
        encoder_depth_list=[2, 2, 2, 3, 3, 3],
        encoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
        encoder_norm="rms2d",
        encoder_act="silu",
        downsample_block_type="Conv",
        decoder_width_list=[128, 256, 512, 512, 1024, 1024],
        decoder_depth_list=[3, 3, 3, 3, 3, 3],
        decoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
        decoder_norm="rms2d",
        decoder_act="silu",
        upsample_block_type="InterpolateConv",
        scaling_factor=0.41407,
    ).to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    parser.add_argument(
        "--typecast_text_encoder",
        action="store_true",
        default=False,
        help="Whether or not to apply fp16/bf16 precision to text_encoder",
    )
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}


if __name__ == "__main__":
    args = get_args()

    transformer = None
    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None

    # if args.transformer_ckpt_path is not None:
    #     transformer = convert_transformer(args.transformer_ckpt_path, dtype)
    #     if not args.save_pipeline:
    #         transformer.save_pretrained(
    #             args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
    #         )

    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path, dtype)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
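As a usage sketch (not part of the commit): the new script's filename is not visible in this view, but its VAE path can be exercised either through the CLI flags --vae_ckpt_path/--output_path/--dtype or by calling convert_vae directly; the paths below are placeholders.

# Hypothetical usage of convert_vae defined above; both paths are placeholders and the
# original DC-AE checkpoint is assumed to be a single safetensors file.
import torch

vae = convert_vae("/path/to/dc-ae/model.safetensors", dtype=torch.float32)
vae.save_pretrained("/path/to/dc-ae-diffusers", safe_serialization=True, max_shard_size="5GB")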

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,7 @@
             "AllegroTransformer3DModel",
             "AsymmetricAutoencoderKL",
             "AuraFlowTransformer2DModel",
+            "AutoencoderDC",
             "AutoencoderKL",
             "AutoencoderKLAllegro",
             "AutoencoderKLCogVideoX",
@@ -572,6 +573,7 @@
             AsymmetricAutoencoderKL,
             AuraFlowTransformer2DModel,
             AutoencoderKL,
+            AutoencoderDC,
             AutoencoderKLAllegro,
             AutoencoderKLCogVideoX,
             AutoencoderKLMochi,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,7 @@
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
     _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
     _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
@@ -88,6 +89,7 @@
         from .adapter import MultiAdapter, T2IAdapter
         from .autoencoders import (
             AsymmetricAutoencoderKL,
+            AutoencoderDC,
             AutoencoderKL,
             AutoencoderKLAllegro,
             AutoencoderKLCogVideoX,

src/diffusers/models/attention.py

Lines changed: 1 addition & 158 deletions
@@ -22,7 +22,7 @@
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU, get_activation
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX, RMSNorm2d
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX, RMSNormNd


 logger = logging.get_logger(__name__)
@@ -1241,160 +1241,3 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         for module in self.net:
             hidden_states = module(hidden_states)
         return hidden_states
-
-
-class DCAELiteMLA(nn.Module):
-    r"""Lightweight multi-scale linear attention used in DC-AE"""
-
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        heads: Optional[int] = None,
-        heads_ratio: float = 1.0,
-        dim=8,
-        use_bias=(False, False),
-        norm=(None, "bn2d"),
-        act_func=(None, None),
-        kernel_func="relu",
-        scales: Tuple[int, ...] = (5,),
-        eps=1.0e-15,
-    ):
-        super().__init__()
-        self.eps = eps
-        heads = int(in_channels // dim * heads_ratio) if heads is None else heads
-
-        total_dim = heads * dim
-
-        self.dim = dim
-
-        qkv = [nn.Conv2d(in_channels=in_channels, out_channels=3 * total_dim, kernel_size=1, bias=use_bias[0])]
-        if norm[0] is None:
-            pass
-        elif norm[0] == "rms2d":
-            qkv.append(RMSNorm2d(num_features=3 * total_dim))
-        elif norm[0] == "bn2d":
-            qkv.append(nn.BatchNorm2d(num_features=3 * total_dim))
-        else:
-            raise ValueError(f"norm {norm[0]} is not supported")
-        if act_func[0] is not None:
-            qkv.append(get_activation(act_func[0]))
-        self.qkv = nn.Sequential(*qkv)
-
-        self.aggreg = nn.ModuleList(
-            [
-                nn.Sequential(
-                    nn.Conv2d(
-                        3 * total_dim,
-                        3 * total_dim,
-                        scale,
-                        padding=scale // 2,
-                        groups=3 * total_dim,
-                        bias=use_bias[0],
-                    ),
-                    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
-                )
-                for scale in scales
-            ]
-        )
-        self.kernel_func = get_activation(kernel_func)
-
-        proj = [nn.Conv2d(in_channels=total_dim * (1 + len(scales)), out_channels=out_channels, kernel_size=1, bias=use_bias[1])]
-        if norm[1] is None:
-            pass
-        elif norm[1] == "rms2d":
-            proj.append(RMSNorm2d(num_features=out_channels))
-        elif norm[1] == "bn2d":
-            proj.append(nn.BatchNorm2d(num_features=out_channels))
-        else:
-            raise ValueError(f"norm {norm[1]} is not supported")
-        if act_func[1] is not None:
-            proj.append(get_activation(act_func[1]))
-        self.proj = nn.Sequential(*proj)
-
-    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
-        B, _, H, W = list(qkv.size())
-
-        if qkv.dtype == torch.float16:
-            qkv = qkv.float()
-
-        qkv = torch.reshape(
-            qkv,
-            (
-                B,
-                -1,
-                3 * self.dim,
-                H * W,
-            ),
-        )
-        q, k, v = (
-            qkv[:, :, 0 : self.dim],
-            qkv[:, :, self.dim : 2 * self.dim],
-            qkv[:, :, 2 * self.dim :],
-        )
-
-        # lightweight linear attention
-        q = self.kernel_func(q)
-        k = self.kernel_func(k)
-
-        # linear matmul
-        trans_k = k.transpose(-1, -2)
-
-        v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
-        vk = torch.matmul(v, trans_k)
-        out = torch.matmul(vk, q)
-        if out.dtype == torch.bfloat16:
-            out = out.float()
-        out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)
-
-        out = torch.reshape(out, (B, -1, H, W))
-        return out
-
-    def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
-        B, _, H, W = list(qkv.size())
-
-        qkv = torch.reshape(
-            qkv,
-            (
-                B,
-                -1,
-                3 * self.dim,
-                H * W,
-            ),
-        )
-        q, k, v = (
-            qkv[:, :, 0 : self.dim],
-            qkv[:, :, self.dim : 2 * self.dim],
-            qkv[:, :, 2 * self.dim :],
-        )
-
-        q = self.kernel_func(q)
-        k = self.kernel_func(k)
-
-        att_map = torch.matmul(k.transpose(-1, -2), q)  # b h n n
-        original_dtype = att_map.dtype
-        if original_dtype in [torch.float16, torch.bfloat16]:
-            att_map = att_map.float()
-        att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps)  # b h n n
-        att_map = att_map.to(original_dtype)
-        out = torch.matmul(v, att_map)  # b h d n
-
-        out = torch.reshape(out, (B, -1, H, W))
-        return out
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # generate multi-scale q, k, v
-        qkv = self.qkv(x)
-        multi_scale_qkv = [qkv]
-        for op in self.aggreg:
-            multi_scale_qkv.append(op(qkv))
-        qkv = torch.cat(multi_scale_qkv, dim=1)
-
-        H, W = list(qkv.size())[-2:]
-        if H * W > self.dim:
-            out = self.relu_linear_att(qkv).to(qkv.dtype)
-        else:
-            out = self.relu_quadratic_att(qkv)
-        out = self.proj(out)
-
-        return x + out
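For context on the removed block (DCAELiteMLA is taken out of attention.py, presumably to live with the new DC-AE autoencoder code): relu_linear_att avoids the quadratic attention map by padding v with a row of ones, so a single pair of matmuls produces both the ReLU-weighted sum and its normalizer in time linear in the number of tokens. A standalone sketch of that equivalence, not part of the diff, assuming one head with tensors of shape (dim, tokens):

# Not part of the commit: check that the padded-ones linear form matches the quadratic
# ReLU attention computed by relu_quadratic_att, for a single head of shape (dim, n).
import torch

torch.manual_seed(0)
dim, n, eps = 8, 16, 1e-15
q = torch.relu(torch.randn(dim, n))   # kernel_func = ReLU
k = torch.relu(torch.randn(dim, n))
v = torch.randn(dim, n)

# quadratic reference: weight of key j for query m is <k_j, q_m>, normalized over j
scores = k.transpose(-1, -2) @ q                   # (n, n)
ref = (v @ scores) / (scores.sum(dim=0, keepdim=True) + eps)

# linear form: the appended row of ones accumulates the normalizer alongside v
v_pad = torch.cat([v, torch.ones(1, n)], dim=0)    # (dim + 1, n)
out = (v_pad @ k.transpose(-1, -2)) @ q            # (dim + 1, n)
out = out[:-1] / (out[-1:] + eps)

print(torch.allclose(ref, out, atol=1e-5))         # expected: True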

src/diffusers/models/autoencoders/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .autoencoder_asym_kl import AsymmetricAutoencoderKL
+from .autoencoder_dc import AutoencoderDC
 from .autoencoder_kl import AutoencoderKL
 from .autoencoder_kl_allegro import AutoencoderKLAllegro
 from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
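Taken together, the three __init__ edits expose the new class at every package level. A quick smoke test one might run on this branch, assuming the registrations above are the only wiring needed:

# Not part of the commit: all three import paths should resolve to the same class
# once this branch is installed.
from diffusers import AutoencoderDC
from diffusers.models import AutoencoderDC as FromModels
from diffusers.models.autoencoders import AutoencoderDC as FromAutoencoders

assert AutoencoderDC is FromModels is FromAutoencoders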
