Skip to content

Commit 78c7861

Browse files
committed
update
1 parent aa7659b commit 78c7861

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
205205
revision = kwargs.pop("revision", None)
206206
torch_dtype = kwargs.pop("torch_dtype", None)
207207
quantization_config = kwargs.pop("quantization_config", None)
208+
device = kwargs.pop("device", None)
208209

209210
if isinstance(pretrained_model_link_or_path_or_dict, dict):
210211
checkpoint = pretrained_model_link_or_path_or_dict
@@ -326,10 +327,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
326327
)
327328

328329
if is_accelerate_available():
330+
param_device = torch.device(device) if device else torch.device("cpu")
329331
unexpected_keys = load_model_dict_into_meta(
330332
model,
331333
diffusers_format_checkpoint,
332334
dtype=torch_dtype,
335+
device=param_device,
333336
hf_quantizer=hf_quantizer,
334337
keep_in_fp32_modules=keep_in_fp32_modules,
335338
)

src/diffusers/models/model_loading_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ def load_model_dict_into_meta(
184184
) -> List[str]:
185185
if device is not None and not isinstance(device, (str, torch.device)):
186186
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
187-
device = device or torch.device("cpu")
187+
if hf_quantizer is None:
188+
device = device or torch.device("cpu")
188189
dtype = dtype or torch.float32
189190
is_quantized = hf_quantizer is not None
190191

0 commit comments

Comments
 (0)