diff --git a/examples/community/checkpoint_merger.py b/examples/community/checkpoint_merger.py
index 6ba4b8c6e837..f23e8a207e36 100644
--- a/examples/community/checkpoint_merger.py
+++ b/examples/community/checkpoint_merger.py
@@ -92,9 +92,13 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
         token = kwargs.pop("token", None)
         variant = kwargs.pop("variant", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         device_map = kwargs.pop("device_map", None)

+        if not isinstance(torch_dtype, torch.dtype):
+            print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
+            torch_dtype = torch.float32
+
         alpha = kwargs.pop("alpha", 0.5)
         interp = kwargs.pop("interp", None)

diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index c87d2a7cf8da..fdfbb923bae8 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -360,11 +360,17 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
         cache_dir = kwargs.pop("cache_dir", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         disable_mmap = kwargs.pop("disable_mmap", False)

         is_legacy_loading = False

+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
         # We shouldn't allow configuring individual models components through a Pipeline creation method
         # These model kwargs should be deprecated
         scaling_factor = kwargs.get("scaling_factor", None)
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index b6eaffbc8c80..e6b050833485 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -240,11 +240,17 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
         config_revision = kwargs.pop("config_revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
         if isinstance(pretrained_model_link_or_path_or_dict, dict):
             checkpoint = pretrained_model_link_or_path_or_dict
         else:
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index e7f306da6bc4..4fbbd78667e3 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -866,7 +866,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
@@ -879,6 +879,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 26bd938b2734..2ef67aa24e8a 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -684,7 +684,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         custom_revision = kwargs.pop("custom_revision", None)
         provider = kwargs.pop("provider", None)
@@ -701,6 +701,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
         if low_cpu_mem_usage and not is_accelerate_available():
             low_cpu_mem_usage = False
             logger.warning(
@@ -1829,7 +1835,7 @@ def from_pipe(cls, pipeline, **kwargs):
         """
         original_config = dict(pipeline.config)

-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)

         # derive the pipeline class to instantiate
         custom_pipeline = kwargs.pop("custom_pipeline", None)
diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py
index cf0b392ddc06..edeb5884144c 100644
--- a/tests/pipelines/kolors/test_kolors.py
+++ b/tests/pipelines/kolors/test_kolors.py
@@ -89,7 +89,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

         components = {
diff --git a/tests/pipelines/kolors/test_kolors_img2img.py b/tests/pipelines/kolors/test_kolors_img2img.py
index 025bcf2fac74..9c43e0920e03 100644
--- a/tests/pipelines/kolors/test_kolors_img2img.py
+++ b/tests/pipelines/kolors/test_kolors_img2img.py
@@ -93,7 +93,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

         components = {
diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py
index 9a5764e24f59..f6d7331b1ad3 100644
--- a/tests/pipelines/pag/test_pag_kolors.py
+++ b/tests/pipelines/pag/test_pag_kolors.py
@@ -98,7 +98,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

         components = {
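
For reference, every loader hunk above adds the same guard; the self-contained sketch below shows the resulting behavior. `_resolve_torch_dtype` is a hypothetical helper name used only for illustration (the patch inlines the check in each loader rather than factoring it out):

    import torch

    def _resolve_torch_dtype(torch_dtype):
        # Mirror the guard added above: warn when the value is not a
        # torch.dtype, then fall back to torch.float32.
        if not isinstance(torch_dtype, torch.dtype):
            print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
            return torch.float32
        return torch_dtype

    _resolve_torch_dtype("float16")       # warns; returns torch.float32
    _resolve_torch_dtype(torch.bfloat16)  # valid dtype; returned unchanged

Note that because the default is now `torch.float32` rather than `None`, string values such as `"float16"` (a common user mistake) are coerced with a warning instead of failing later inside `.to()`.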