Skip to content

Commit bf6c211

Browse files
committed
update
1 parent 7b9d7e5 commit bf6c211

File tree

2 files changed

+170
-294
lines changed

2 files changed

+170
-294
lines changed

scripts/convert_dcae_to_diffusers.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
2626
"point_conv": "conv_point",
2727
"inverted_conv": "conv_inverted",
2828
"conv.conv.": "conv.",
29+
"conv1.conv": "conv1",
30+
"conv2.conv": "conv2",
31+
"conv1.norm": "norm2",
32+
"conv2.norm": "norm2",
33+
"qkv.conv": "qkv",
34+
"proj.conv": "proj_out",
35+
"proj.norm": "norm_out",
2936
# encoder
3037
"encoder.project_in.conv": "encoder.conv_in",
3138
"encoder.project_out.0.conv": "encoder.conv_out",
@@ -114,6 +121,74 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
114121
return vae
115122

116123

124+
def get_vae_config(name: str) -> Dict[str, Any]:
    """Return the hard-coded VAE configuration for a known model name.

    Args:
        name: Model identifier, e.g. ``"dc-ae-f32c32-sana-1.0"``. Must be one
            of the names listed in the branches below.

    Returns:
        A freshly built dictionary of configuration entries (latent channel
        count, per-stage block types, channel widths, layer counts, ...).
        A new dict is created on each call, so callers may mutate it.

    Raises:
        ValueError: If ``name`` is not one of the supported model names.
    """
    if name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
            "block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [2, 2, 2, 3, 3, 3],
            "downsample_block_type": "Conv",
            "decoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "upsample_block_type": "InterpolateConv",
            "scaling_factor": 0.41407,
        }
    elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
            "block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "decoder_norm": ["bn2d", "bn2d", "bn2d", "rms2d", "rms2d", "rms2d"],
            "decoder_act": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
    elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        config = {
            "latent_channels": 512,
            "encoder_block_type": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EViT_GLU",
                "EViT_GLU",
                "EViT_GLU",
                "EViT_GLU",
                "EViT_GLU",
            ],
            "block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_block_type": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EViT_GLU",
                "EViT_GLU",
                "EViT_GLU",
                "EViT_GLU",
                "EViT_GLU",
            ],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "decoder_norm": ["bn2d", "bn2d", "bn2d", "rms2d", "rms2d", "rms2d", "rms2d", "rms2d"],
            "decoder_act": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
    elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        config = {
            "latent_channels": 128,
            "encoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
            "block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "decoder_norm": ["bn2d", "bn2d", "bn2d", "rms2d", "rms2d", "rms2d", "rms2d"],
            "decoder_act": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
    else:
        # Without this branch `config` would be unbound and the return below
        # would fail with an UnboundLocalError that hides the real cause.
        raise ValueError(f"Unsupported VAE name: {name!r}")

    return config
190+
191+
117192
def get_args():
118193
parser = argparse.ArgumentParser()
119194
parser.add_argument(

0 commit comments

Comments
 (0)