@@ -685,7 +685,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         custom_revision = kwargs.pop("custom_revision", None)
         provider = kwargs.pop("provider", None)
@@ -702,6 +702,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
 
+        if not isinstance(torch_dtype, torch.dtype):
+            torch_dtype = torch.float32
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+
         if low_cpu_mem_usage and not is_accelerate_available():
             low_cpu_mem_usage = False
             logger.warning(
@@ -1826,7 +1832,7 @@ def from_pipe(cls, pipeline, **kwargs):
         """
 
         original_config = dict(pipeline.config)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
 
         # derive the pipeline class to instantiate
         custom_pipeline = kwargs.pop("custom_pipeline", None)