@@ -1023,17 +1023,89 @@ def get_diffusion_models_for_export_ext(
10231023        is_flux  =  isinstance (pipeline , tuple (flux_pipes ))
10241024    else :
10251025        is_flux  =  False 
1026+     
1027+     try :
1028+         from  diffusers  import  SanaPipeline 
1029+         is_sana  =  isinstance (pipeline , SanaPipeline )
1030+     except  ImportError :
1031+         is_sana  =  False 
10261032
1027-     if  not  is_sd3   and   not   is_flux :
1033+     if  not  any ([ is_sana ,  is_flux ,  is_sd3 ]) :
10281034        return  None , get_diffusion_models_for_export (pipeline , int_dtype , float_dtype , exporter )
10291035    if  is_sd3 :
10301036        models_for_export  =  get_sd3_models_for_export (pipeline , exporter , int_dtype , float_dtype )
1037+     if  is_sana :
1038+         models_for_export  =  get_sana_models_for_export (pipeline , exporter , int_dtype , float_dtype )
10311039    else :
10321040        models_for_export  =  get_flux_models_for_export (pipeline , exporter , int_dtype , float_dtype )
10331041
10341042    return  None , models_for_export 
10351043
10361044
1045+ def  get_sana_models_for_export (pipeline , exporter , int_dtype , float_dtype ):
1046+     DEFAULT_DUMMY_SHAPES ["heigh" ] =  DEFAULT_DUMMY_SHAPES ["height" ] //  4 
1047+     DEFAULT_DUMMY_SHAPES ["width" ] =  DEFAULT_DUMMY_SHAPES ["width" ] //  4 
1048+     models_for_export  =  {}
1049+     text_encoder  =  pipeline .text_encoder 
1050+     text_encoder_config_constructor  =  TasksManager .get_exporter_config_constructor (
1051+             model = text_encoder ,
1052+             exporter = exporter ,
1053+             library_name = "diffusers" ,
1054+             task = "feature-extraction" ,
1055+             model_type = "gemma2-text-encoder" ,
1056+         )
1057+     text_encoder_export_config  =  text_encoder_config_constructor (
1058+         pipeline .text_encoder .config , int_dtype = int_dtype , float_dtype = float_dtype 
1059+     )
1060+     models_for_export ["text_encoder" ] =  (text_encoder , text_encoder_export_config )
1061+     transformer  =  pipeline .transformer 
1062+     transformer .config .text_encoder_projection_dim  =  transformer .config .caption_channels 
1063+     transformer .config .requires_aesthetics_score  =  False 
1064+     transformer .config .time_cond_proj_dim  =  None 
1065+     export_config_constructor  =  TasksManager .get_exporter_config_constructor (
1066+         model = transformer ,
1067+         exporter = exporter ,
1068+         library_name = "diffusers" ,
1069+         task = "semantic-segmentation" ,
1070+         model_type = "sana-transformer" ,
1071+     )
1072+     transformer_export_config  =  export_config_constructor (
1073+         pipeline .transformer .config , int_dtype = int_dtype , float_dtype = float_dtype 
1074+     )
1075+     models_for_export ["transformer" ] =  (transformer , transformer_export_config )
1076+     # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565 
1077+     vae_encoder  =  copy .deepcopy (pipeline .vae )
1078+     vae_encoder .forward  =  lambda  sample : {"latent_parameters" : vae_encoder .encode (x = sample )["latent_dist" ].parameters }
1079+     vae_config_constructor  =  TasksManager .get_exporter_config_constructor (
1080+         model = vae_encoder ,
1081+         exporter = exporter ,
1082+         library_name = "diffusers" ,
1083+         task = "semantic-segmentation" ,
1084+         model_type = "vae-encoder" ,
1085+     )
1086+     vae_encoder_export_config  =  vae_config_constructor (
1087+         vae_encoder .config , int_dtype = int_dtype , float_dtype = float_dtype 
1088+     )
1089+     models_for_export ["vae_encoder" ] =  (vae_encoder , vae_encoder_export_config )
1090+ 
1091+     # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600 
1092+     vae_decoder  =  copy .deepcopy (pipeline .vae )
1093+     vae_decoder .forward  =  lambda  latent_sample : vae_decoder .decode (z = latent_sample )
1094+     vae_config_constructor  =  TasksManager .get_exporter_config_constructor (
1095+         model = vae_decoder ,
1096+         exporter = exporter ,
1097+         library_name = "diffusers" ,
1098+         task = "semantic-segmentation" ,
1099+         model_type = "vae-decoder" ,
1100+     )
1101+     vae_decoder_export_config  =  vae_config_constructor (
1102+         vae_decoder .config , int_dtype = int_dtype , float_dtype = float_dtype 
1103+     )
1104+     models_for_export ["vae_decoder" ] =  (vae_decoder , vae_decoder_export_config )
1105+ 
1106+     return  models_for_export 
1107+ 
1108+ 
10371109def  get_sd3_models_for_export (pipeline , exporter , int_dtype , float_dtype ):
10381110    models_for_export  =  {}
10391111
0 commit comments