@@ -99,7 +99,7 @@ def convert_ae(config_name: str, dtype: torch.dtype):
9999 hub_id = f"mit-han-lab/{ config_name } "
100100 ckpt_path = hf_hub_download (hub_id , "model.safetensors" )
101101 original_state_dict = get_state_dict (load_file (ckpt_path ))
102-
102+
103103 ae = AutoencoderDC (** config ).to (dtype = dtype )
104104
105105 for key in list (original_state_dict .keys ()):
@@ -122,8 +122,22 @@ def get_ae_config(name: str):
122122 if name in ["dc-ae-f32c32-sana-1.0" ]:
123123 config = {
124124 "latent_channels" : 32 ,
125- "encoder_block_types" : ("ResBlock" , "ResBlock" , "ResBlock" , "EfficientViTBlock" , "EfficientViTBlock" , "EfficientViTBlock" ),
126- "decoder_block_types" : ("ResBlock" , "ResBlock" , "ResBlock" , "EfficientViTBlock" , "EfficientViTBlock" , "EfficientViTBlock" ),
125+ "encoder_block_types" : (
126+ "ResBlock" ,
127+ "ResBlock" ,
128+ "ResBlock" ,
129+ "EfficientViTBlock" ,
130+ "EfficientViTBlock" ,
131+ "EfficientViTBlock" ,
132+ ),
133+ "decoder_block_types" : (
134+ "ResBlock" ,
135+ "ResBlock" ,
136+ "ResBlock" ,
137+ "EfficientViTBlock" ,
138+ "EfficientViTBlock" ,
139+ "EfficientViTBlock" ,
140+ ),
127141 "encoder_block_out_channels" : (128 , 256 , 512 , 512 , 1024 , 1024 ),
128142 "decoder_block_out_channels" : (128 , 256 , 512 , 512 , 1024 , 1024 ),
129143 "encoder_qkv_multiscales" : ((), (), (), (5 ,), (5 ,), (5 ,)),
0 commit comments