
Commit 98d0cd5

a-r-r-o-w and sayakpaul authored

Use torch.device instead of current device index for BnB quantizer (#10069)

* update
* apply review suggestion

Co-authored-by: Sayak Paul <[email protected]>
1 parent 0d11ab2 commit 98d0cd5

File tree

2 files changed: +3 -1 lines changed


src/diffusers/models/model_loading_utils.py

Lines changed: 2 additions & 0 deletions
@@ -176,6 +176,8 @@ def load_model_dict_into_meta(
     hf_quantizer=None,
     keep_in_fp32_modules=None,
 ) -> List[str]:
+    if device is not None and not isinstance(device, (str, torch.device)):
+        raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
     if hf_quantizer is None:
         device = device or torch.device("cpu")
         dtype = dtype or torch.float32
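
For context, a minimal standalone sketch of what the added guard does (the helper name _validate_device is hypothetical, not the diffusers function; only the isinstance check mirrors the diff): string and torch.device inputs pass, a bare integer index is rejected.

import torch

def _validate_device(device):
    # Sketch of the guard added to load_model_dict_into_meta: only `str` or
    # `torch.device` are accepted; bare integer device indices now raise.
    if device is not None and not isinstance(device, (str, torch.device)):
        raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
    return device

_validate_device("cuda")                # accepted
_validate_device(torch.device("cpu"))   # accepted
try:
    _validate_device(0)                 # an integer index is rejected
except ValueError as err:
    print(err)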

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -836,7 +836,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 param_device = "cpu"
             # TODO (sayakpaul, SunMarc): remove this after model loading refactor
             elif is_quant_method_bnb:
-                param_device = torch.cuda.current_device()
+                param_device = torch.device(torch.cuda.current_device())
             state_dict = load_state_dict(model_file, variant=variant)
             model._convert_deprecated_attention_blocks(state_dict)
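
For illustration of why the one-line change matters: torch.cuda.current_device() returns a plain integer index, which the new guard in load_model_dict_into_meta rejects, while wrapping it in torch.device(...) yields an accepted torch.device object. A minimal sketch that falls back to CPU when no GPU is available:

import torch

if torch.cuda.is_available():
    index = torch.cuda.current_device()      # an int such as 0, not a torch.device
    param_device = torch.device(index)       # what the commit now passes downstream
    print(type(index), type(param_device))   # <class 'int'> <class 'torch.device'>
else:
    param_device = torch.device("cpu")       # fallback so the sketch runs without CUDA
print(param_device)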
