@@ -685,7 +685,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
685685 token = kwargs .pop ("token" , None )
686686 revision = kwargs .pop ("revision" , None )
687687 from_flax = kwargs .pop ("from_flax" , False )
688- torch_dtype = kwargs .pop ("torch_dtype" , None )
688+ torch_dtype = kwargs .pop ("torch_dtype" , torch . float32 )
689689 custom_pipeline = kwargs .pop ("custom_pipeline" , None )
690690 custom_revision = kwargs .pop ("custom_revision" , None )
691691 provider = kwargs .pop ("provider" , None )
@@ -702,6 +702,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
702702 use_onnx = kwargs .pop ("use_onnx" , None )
703703 load_connected_pipeline = kwargs .pop ("load_connected_pipeline" , False )
704704
705+ if not isinstance (torch_dtype , torch .dtype ):
706+ torch_dtype = torch .float32
707+ logger .warning (
708+ f"Passed `torch_dtype` { torch_dtype } is not a `torch.dtype`. Defaulting to `torch.float32`."
709+ )
710+
705711 if low_cpu_mem_usage and not is_accelerate_available ():
706712 low_cpu_mem_usage = False
707713 logger .warning (
@@ -1826,7 +1832,7 @@ def from_pipe(cls, pipeline, **kwargs):
18261832 """
18271833
18281834 original_config = dict (pipeline .config )
1829- torch_dtype = kwargs .pop ("torch_dtype" , None )
1835+ torch_dtype = kwargs .pop ("torch_dtype" , torch . float32 )
18301836
18311837 # derive the pipeline class to instantiate
18321838 custom_pipeline = kwargs .pop ("custom_pipeline" , None )
0 commit comments