@@ -28,7 +28,7 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
2828 "conv.conv." : "conv." ,
2929 "conv1.conv" : "conv1" ,
3030 "conv2.conv" : "conv2" ,
31- "conv1.norm" : "norm2 " ,
31+ "conv1.norm" : "norm1 " ,
3232 "conv2.norm" : "norm2" ,
3333 "qkv.conv" : "qkv" ,
3434 "proj.conv" : "proj_out" ,
@@ -90,14 +90,11 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
9090 vae = AutoencoderDC (
9191 in_channels = 3 ,
9292 latent_channels = 32 ,
93- encoder_width_list = [128 , 256 , 512 , 512 , 1024 , 1024 ],
94- encoder_depth_list = [2 , 2 , 2 , 3 , 3 , 3 ],
93+ block_out_channels = [128 , 256 , 512 , 512 , 1024 , 1024 ],
94+ encoder_layers_per_block = [2 , 2 , 2 , 3 , 3 , 3 ],
9595 encoder_block_type = ["ResBlock" , "ResBlock" , "ResBlock" , "EViTS5_GLU" , "EViTS5_GLU" , "EViTS5_GLU" ],
96- encoder_norm = "rms2d" ,
97- encoder_act = "silu" ,
9896 downsample_block_type = "Conv" ,
99- decoder_width_list = [128 , 256 , 512 , 512 , 1024 , 1024 ],
100- decoder_depth_list = [3 , 3 , 3 , 3 , 3 , 3 ],
97+ decoder_layers_per_block = [3 , 3 , 3 , 3 , 3 , 3 ],
10198 decoder_block_type = ["ResBlock" , "ResBlock" , "ResBlock" , "EViTS5_GLU" , "EViTS5_GLU" , "EViTS5_GLU" ],
10299 decoder_norm = "rms2d" ,
103100 decoder_act = "silu" ,
0 commit comments