Skip to content

Commit f5876c5

Browse files
committed
fix
1 parent a2ec5f8 commit f5876c5

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

scripts/convert_dcae_to_diffusers.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
2828
"conv.conv.": "conv.",
2929
"conv1.conv": "conv1",
3030
"conv2.conv": "conv2",
31-
"conv1.norm": "norm2",
31+
"conv1.norm": "norm1",
3232
"conv2.norm": "norm2",
3333
"qkv.conv": "qkv",
3434
"proj.conv": "proj_out",
@@ -90,14 +90,11 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
9090
vae = AutoencoderDC(
9191
in_channels=3,
9292
latent_channels=32,
93-
encoder_width_list=[128, 256, 512, 512, 1024, 1024],
94-
encoder_depth_list=[2, 2, 2, 3, 3, 3],
93+
block_out_channels=[128, 256, 512, 512, 1024, 1024],
94+
encoder_layers_per_block=[2, 2, 2, 3, 3, 3],
9595
encoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
96-
encoder_norm="rms2d",
97-
encoder_act="silu",
9896
downsample_block_type="Conv",
99-
decoder_width_list=[128, 256, 512, 512, 1024, 1024],
100-
decoder_depth_list=[3, 3, 3, 3, 3, 3],
97+
decoder_layers_per_block=[3, 3, 3, 3, 3, 3],
10198
decoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
10299
decoder_norm="rms2d",
103100
decoder_act="silu",

0 commit comments

Comments
 (0)