1616import  argparse 
1717import  logging 
1818import  sys 
19+ import  time  as  time 
1920from  collections .abc  import  Callable 
2021from  dataclasses  import  dataclass 
2122from  enum  import  Enum 
@@ -59,6 +60,7 @@ class ModelType(str, Enum):
5960    SDXL_BASE  =  "sdxl-1.0" 
6061    SDXL_TURBO  =  "sdxl-turbo" 
6162    SD3_MEDIUM  =  "sd3-medium" 
63+     SD35_MEDIUM  =  "sd3.5-medium" 
6264    FLUX_DEV  =  "flux-dev" 
6365    FLUX_SCHNELL  =  "flux-schnell" 
6466    LTX_VIDEO_DEV  =  "ltx-video-dev" 
@@ -114,6 +116,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
114116        ModelType .SDXL_BASE : filter_func_default ,
115117        ModelType .SDXL_TURBO : filter_func_default ,
116118        ModelType .SD3_MEDIUM : filter_func_default ,
119+         ModelType .SD35_MEDIUM : filter_func_default ,
117120        ModelType .LTX_VIDEO_DEV : filter_func_ltx_video ,
118121    }
119122
@@ -125,6 +128,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
125128    ModelType .SDXL_BASE : "stabilityai/stable-diffusion-xl-base-1.0" ,
126129    ModelType .SDXL_TURBO : "stabilityai/sdxl-turbo" ,
127130    ModelType .SD3_MEDIUM : "stabilityai/stable-diffusion-3-medium-diffusers" ,
131+     ModelType .SD35_MEDIUM : "stabilityai/stable-diffusion-3.5-medium" ,
128132    ModelType .FLUX_DEV : "black-forest-labs/FLUX.1-dev" ,
129133    ModelType .FLUX_SCHNELL : "black-forest-labs/FLUX.1-schnell" ,
130134    ModelType .LTX_VIDEO_DEV : "Lightricks/LTX-Video-0.9.7-dev" ,
@@ -230,6 +234,7 @@ def uses_transformer(self) -> bool:
230234        """Check if model uses transformer backbone (vs UNet).""" 
231235        return  self .model_type  in  [
232236            ModelType .SD3_MEDIUM ,
237+             ModelType .SD35_MEDIUM ,
233238            ModelType .FLUX_DEV ,
234239            ModelType .FLUX_SCHNELL ,
235240            ModelType .LTX_VIDEO_DEV ,
@@ -326,7 +331,7 @@ def create_pipeline_from(
326331            model_id  =  (
327332                MODEL_REGISTRY [model_type ] if  override_model_path  is  None  else  override_model_path 
328333            )
329-             if  model_type  ==   ModelType .SD3_MEDIUM :
334+             if  model_type  in  [ ModelType .SD3_MEDIUM ,  ModelType . SD35_MEDIUM ] :
330335                pipe  =  StableDiffusion3Pipeline .from_pretrained (model_id , torch_dtype = torch_dtype )
331336            elif  model_type  in  [ModelType .FLUX_DEV , ModelType .FLUX_SCHNELL ]:
332337                pipe  =  FluxPipeline .from_pretrained (model_id , torch_dtype = torch_dtype )
@@ -357,7 +362,7 @@ def create_pipeline(self) -> DiffusionPipeline:
357362        self .logger .info (f"Data type: { self .config .model_dtype .value }  " )
358363
359364        try :
360-             if  self .config .model_type  ==   ModelType .SD3_MEDIUM :
365+             if  self .config .model_type  in  [ ModelType .SD3_MEDIUM ,  ModelType . SD35_MEDIUM ] :
361366                self .pipe  =  StableDiffusion3Pipeline .from_pretrained (
362367                    self .config .model_path , torch_dtype = self .config .torch_dtype 
363368                )
@@ -864,6 +869,8 @@ def main() -> None:
864869    parser  =  create_argument_parser ()
865870    args  =  parser .parse_args ()
866871
872+     s  =  time .time ()
873+ 
867874    logger  =  setup_logging (args .verbose )
868875    logger .info ("Starting Enhanced Diffusion Model Quantization" )
869876
@@ -939,9 +946,11 @@ def forward_loop(mod):
939946            backbone ,
940947            model_config .model_type ,
941948            quant_config .format ,
942-             quantize_mha = QuantizationConfig .quantize_mha ,
949+             quantize_mha = quant_config .quantize_mha ,
950+         )
951+         logger .info (
952+             f"Quantization process completed successfully! Time taken = { time .time () -  s }   seconds" 
943953        )
944-         logger .info ("Quantization process completed successfully!" )
945954
946955    except  Exception  as  e :
947956        logger .error (f"Quantization failed: { e }  " , exc_info = True )
0 commit comments