Skip to content

Commit c87646a

Browse files
committed
Don't use torch_dtype when quantization_config is set
1 parent 5428046 commit c87646a

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
266266
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
267267
)
268268

269+
if quantization_config is not None and torch_dtype is not None:
270+
torch_dtype = None
271+
269272
if isinstance(pretrained_model_link_or_path_or_dict, dict):
270273
checkpoint = pretrained_model_link_or_path_or_dict
271274
else:

src/diffusers/models/modeling_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
889889
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
890890
)
891891

892+
if quantization_config is not None and torch_dtype is not None:
893+
torch_dtype = None
894+
892895
allow_pickle = False
893896
if use_safetensors is None:
894897
use_safetensors = True

0 commit comments

Comments (0)