99from accelerate import init_empty_weights
1010from diffusers import (
1111 DCAE ,
12- DCAE_HF ,
13- FlowDPMSolverMultistepScheduler ,
12+ DPMSolverMultistepScheduler ,
1413 FlowMatchEulerDiscreteScheduler ,
1514 SanaPAGPipeline ,
1615 SanaTransformer2DModel ,
@@ -186,27 +185,10 @@ def main(args):
186185 else :
187186 print (colored (f"Saving the whole SanaPAGPipeline containing { args .model_type } " , "green" , attrs = ["bold" ]))
188187 # VAE
189- dc_ae = DCAE_HF .from_pretrained (f"mit-han-lab/dc-ae-f32c32-sana-1.0" )
190- dc_ae_state_dict = dc_ae .state_dict ()
191- dc_ae = DCAE (
192- in_channels = 3 ,
193- latent_channels = 32 ,
194- encoder_width_list = [128 , 256 , 512 , 512 , 1024 , 1024 ],
195- encoder_depth_list = [2 , 2 , 2 , 3 , 3 , 3 ],
196- encoder_block_type = ["ResBlock" , "ResBlock" , "ResBlock" , "EViTS5_GLU" , "EViTS5_GLU" , "EViTS5_GLU" ],
197- encoder_norm = "rms2d" ,
198- encoder_act = "silu" ,
199- downsample_block_type = "Conv" ,
200- decoder_width_list = [128 , 256 , 512 , 512 , 1024 , 1024 ],
201- decoder_depth_list = [3 , 3 , 3 , 3 , 3 , 3 ],
202- decoder_block_type = ["ResBlock" , "ResBlock" , "ResBlock" , "EViTS5_GLU" , "EViTS5_GLU" , "EViTS5_GLU" ],
203- decoder_norm = "rms2d" ,
204- decoder_act = "silu" ,
205- upsample_block_type = "InterpolateConv" ,
206- scaling_factor = 0.41407 ,
207- )
208- dc_ae .load_state_dict (dc_ae_state_dict , strict = True )
209- dc_ae .to (torch .float32 ).to (device )
188+ dc_ae = DCAE .from_pretrained (
189+ "Efficient-Large-Model/dc_ae_f32c32_sana_1.0_diffusers" ,
190+ torch_dtype = torch .float32 ,
191+ ).to (device )
210192
211193 # Text Encoder
212194 text_encoder_model_path = "google/gemma-2-2b-it"
@@ -220,7 +202,11 @@ def main(args):
220202
221203 # Scheduler
222204 if args .scheduler_type == "flow-dpm_solver" :
223- scheduler = FlowDPMSolverMultistepScheduler (flow_shift = flow_shift )
205+ scheduler = DPMSolverMultistepScheduler (
206+ flow_shift = flow_shift ,
207+ use_flow_sigmas = True ,
208+ prediction_type = "flow_prediction" ,
209+ )
224210 elif args .scheduler_type == "flow-euler" :
225211 scheduler = FlowMatchEulerDiscreteScheduler (shift = flow_shift )
226212 else :
0 commit comments