@@ -26,6 +26,13 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
2626 "point_conv" : "conv_point" ,
2727 "inverted_conv" : "conv_inverted" ,
2828 "conv.conv." : "conv." ,
29+ "conv1.conv" : "conv1" ,
30+ "conv2.conv" : "conv2" ,
31+ "conv1.norm" : "norm2" ,
32+ "conv2.norm" : "norm2" ,
33+ "qkv.conv" : "qkv" ,
34+ "proj.conv" : "proj_out" ,
35+ "proj.norm" : "norm_out" ,
2936 # encoder
3037 "encoder.project_in.conv" : "encoder.conv_in" ,
3138 "encoder.project_out.0.conv" : "encoder.conv_out" ,
@@ -114,6 +121,74 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
114121 return vae
115122
116123
124+ def get_vae_config (name : str ):
125+ if name in ["dc-ae-f32c32-sana-1.0" ]:
126+ config = {
127+ "latent_channels" : 32 ,
128+ "encoder_block_type" : ["ResBlock" , "ResBlock" , "ResBlock" , "EViTS5_GLU" , "EViTS5_GLU" , "EViTS5_GLU" ],
129+ "block_out_channels" : [128 , 256 , 512 , 512 , 1024 , 1024 ],
130+ "encoder_layers_per_block" : [2 , 2 , 2 , 3 , 3 , 3 ],
131+ "downsample_block_type" : "Conv" ,
132+ "decoder_block_type" : ["ResBlock" , "ResBlock" , "ResBlock" , "EViTS5_GLU" , "EViTS5_GLU" , "EViTS5_GLU" ],
133+ "decoder_layers_per_block" : [3 , 3 , 3 , 3 , 3 , 3 ],
134+ "upsample_block_type" : "InterpolateConv" ,
135+ "scaling_factor" : 0.41407 ,
136+ }
137+ elif name in ["dc-ae-f32c32-in-1.0" , "dc-ae-f32c32-mix-1.0" ]:
138+ config = {
139+ "latent_channels" : 32 ,
140+ "encoder_block_type" : ["ResBlock" , "ResBlock" , "ResBlock" , "EViT_GLU" , "EViT_GLU" , "EViT_GLU" ],
141+ "block_out_channels" : [128 , 256 , 512 , 512 , 1024 , 1024 ],
142+ "encoder_layers_per_block" : [0 , 4 , 8 , 2 , 2 , 2 ],
143+ "decoder_block_type" : ["ResBlock" , "ResBlock" , "ResBlock" , "EViT_GLU" , "EViT_GLU" , "EViT_GLU" ],
144+ "decoder_layers_per_block" : [0 , 5 , 10 , 2 , 2 , 2 ],
145+ "decoder_norm" : ["bn2d" , "bn2d" , "bn2d" , "rms2d" , "rms2d" , "rms2d" ],
146+ "decoder_act" : ["relu" , "relu" , "relu" , "silu" , "silu" , "silu" ],
147+ }
148+ elif name in ["dc-ae-f128c512-in-1.0" , "dc-ae-f128c512-mix-1.0" ]:
149+ config = {
150+ "latent_channels" : 512 ,
151+ "encoder_block_type" : [
152+ "ResBlock" ,
153+ "ResBlock" ,
154+ "ResBlock" ,
155+ "EViT_GLU" ,
156+ "EViT_GLU" ,
157+ "EViT_GLU" ,
158+ "EViT_GLU" ,
159+ "EViT_GLU" ,
160+ ],
161+ "block_out_channels" : [128 , 256 , 512 , 512 , 1024 , 1024 , 2048 , 2048 ],
162+ "encoder_layers_per_block" : [0 , 4 , 8 , 2 , 2 , 2 , 2 , 2 ],
163+ "decoder_block_type" : [
164+ "ResBlock" ,
165+ "ResBlock" ,
166+ "ResBlock" ,
167+ "EViT_GLU" ,
168+ "EViT_GLU" ,
169+ "EViT_GLU" ,
170+ "EViT_GLU" ,
171+ "EViT_GLU" ,
172+ ],
173+ "decoder_layers_per_block" : [0 , 5 , 10 , 2 , 2 , 2 , 2 , 2 ],
174+ "decoder_norm" : ["bn2d" , "bn2d" , "bn2d" , "rms2d" , "rms2d" , "rms2d" , "rms2d" , "rms2d" ],
175+ "decoder_act" : ["relu" , "relu" , "relu" , "silu" , "silu" , "silu" , "silu" , "silu" ],
176+ }
177+ elif name in ["dc-ae-f64c128-in-1.0" , "dc-ae-f64c128-mix-1.0" ]:
178+ config = {
179+ "latent_channels" : 128 ,
180+ "encoder_block_type" : ["ResBlock" , "ResBlock" , "ResBlock" , "EViT_GLU" , "EViT_GLU" , "EViT_GLU" , "EViT_GLU" ],
181+ "block_out_channels" : [128 , 256 , 512 , 512 , 1024 , 1024 , 2048 ],
182+ "encoder_layers_per_block" : [0 , 4 , 8 , 2 , 2 , 2 , 2 ],
183+ "decoder_block_type" : ["ResBlock" , "ResBlock" , "ResBlock" , "EViT_GLU" , "EViT_GLU" , "EViT_GLU" , "EViT_GLU" ],
184+ "decoder_layers_per_block" : [0 , 5 , 10 , 2 , 2 , 2 , 2 ],
185+ "decoder_norm" : ["bn2d" , "bn2d" , "bn2d" , "rms2d" , "rms2d" , "rms2d" , "rms2d" ],
186+ "decoder_act" : ["relu" , "relu" , "relu" , "silu" , "silu" , "silu" , "silu" ],
187+ }
188+
189+ return config
190+
191+
117192def get_args ():
118193 parser = argparse .ArgumentParser ()
119194 parser .add_argument (
0 commit comments