99from  accelerate  import  init_empty_weights 
1010from  diffusers  import  (
1111    DCAE ,
12-     DCAE_HF ,
13-     FlowDPMSolverMultistepScheduler ,
12+     DPMSolverMultistepScheduler ,
1413    FlowMatchEulerDiscreteScheduler ,
1514    SanaPAGPipeline ,
1615    SanaTransformer2DModel ,
@@ -186,27 +185,10 @@ def main(args):
186185    else :
187186        print (colored (f"Saving the whole SanaPAGPipeline containing { args .model_type }  , "green" , attrs = ["bold" ]))
188187        # VAE 
189-         dc_ae  =  DCAE_HF .from_pretrained (f"mit-han-lab/dc-ae-f32c32-sana-1.0" )
190-         dc_ae_state_dict  =  dc_ae .state_dict ()
191-         dc_ae  =  DCAE (
192-             in_channels = 3 ,
193-             latent_channels = 32 ,
194-             encoder_width_list = [128 , 256 , 512 , 512 , 1024 , 1024 ],
195-             encoder_depth_list = [2 , 2 , 2 , 3 , 3 , 3 ],
196-             encoder_block_type = ["ResBlock" , "ResBlock" , "ResBlock" , "EViTS5_GLU" , "EViTS5_GLU" , "EViTS5_GLU" ],
197-             encoder_norm = "rms2d" ,
198-             encoder_act = "silu" ,
199-             downsample_block_type = "Conv" ,
200-             decoder_width_list = [128 , 256 , 512 , 512 , 1024 , 1024 ],
201-             decoder_depth_list = [3 , 3 , 3 , 3 , 3 , 3 ],
202-             decoder_block_type = ["ResBlock" , "ResBlock" , "ResBlock" , "EViTS5_GLU" , "EViTS5_GLU" , "EViTS5_GLU" ],
203-             decoder_norm = "rms2d" ,
204-             decoder_act = "silu" ,
205-             upsample_block_type = "InterpolateConv" ,
206-             scaling_factor = 0.41407 ,
207-         )
208-         dc_ae .load_state_dict (dc_ae_state_dict , strict = True )
209-         dc_ae .to (torch .float32 ).to (device )
188+         dc_ae  =  DCAE .from_pretrained (
189+             "Efficient-Large-Model/dc_ae_f32c32_sana_1.0_diffusers" ,
190+             torch_dtype = torch .float32 ,
191+         ).to (device )
210192
211193        # Text Encoder 
212194        text_encoder_model_path  =  "google/gemma-2-2b-it" 
@@ -220,7 +202,11 @@ def main(args):
220202
221203        # Scheduler 
222204        if  args .scheduler_type  ==  "flow-dpm_solver" :
223-             scheduler  =  FlowDPMSolverMultistepScheduler (flow_shift = flow_shift )
205+             scheduler  =  DPMSolverMultistepScheduler (
206+                 flow_shift = flow_shift , 
207+                 use_flow_sigmas = True ,
208+                 prediction_type = "flow_prediction" ,
209+             )
224210        elif  args .scheduler_type  ==  "flow-euler" :
225211            scheduler  =  FlowMatchEulerDiscreteScheduler (shift = flow_shift )
226212        else :
0 commit comments