Commit 0ae70fe

fix dtype checks in pipeline.
1 parent 700b0f3

2 files changed: +5 additions, -2 deletions

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 1 deletion
@@ -450,7 +450,7 @@ def module_is_offloaded(module):
             # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
             if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
                 module.to(device=device)
-            elif not is_loaded_in_8bit_bnb:
+            elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
                 module.to(device, dtype)

             module_has_int_weights = any(

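Why the guard changed: with the old `elif not is_loaded_in_8bit_bnb:`, a 4-bit bnb module that skipped the first branch (for instance when `device` is None, or on transformers <= 4.44.0) still reached `module.to(device, dtype)` and was dtype-cast, which bitsandbytes does not allow. Below is a minimal sketch of the corrected branching, with the transformers version guard omitted and the two flags passed in explicitly; the helper name `move_module` is illustrative, not from the codebase.

# Minimal sketch of the corrected control flow; the real code also gates
# the first branch on is_transformers_version(">", "4.44.0").
def move_module(module, device, dtype, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb):
    if is_loaded_in_4bit_bnb and device is not None:
        # 4-bit bnb modules may change device, but must keep their dtype.
        module.to(device=device)
    elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
        # Only unquantized modules accept a combined device/dtype move.
        module.to(device, dtype)
    # Before the fix, a 4-bit module that fell past the first branch
    # (e.g. device=None) matched `not is_loaded_in_8bit_bnb` and was
    # cast to `dtype`, which bitsandbytes forbids.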
tests/quantization/bnb/test_4bit.py

Lines changed: 4 additions & 1 deletion
@@ -260,7 +260,7 @@ def test_device_assignment(self):
     def test_device_and_dtype_assignment(self):
         r"""
         Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error.
-        Checks also if other models are casted correctly.
+        Checks also if other models are casted correctly. Device placement, however, is supported.
         """
         with self.assertRaises(ValueError):
             # Tries with a `dtype`
@@ -278,6 +278,9 @@ def test_device_and_dtype_assignment(self):
             # Tries with a cast
             self.model_4bit.half()

+            # This should work
+            self.model_4bit.to("cuda")
+
             # Test if we did not break anything
             self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
             input_dict_for_transformer = self.get_dummy_inputs()

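In user terms, the test now pins down the behavior this commit restores: moving a 4-bit model to a device succeeds, while any dtype cast still raises. A hedged sketch mirroring what the test exercises; the checkpoint id is a placeholder, so substitute any model diffusers can load with a 4-bit BitsAndBytesConfig.

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Placeholder checkpoint; any model diffusers can load in 4-bit works here.
model_4bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

model_4bit.to("cuda")  # device placement: allowed after this commit
try:
    model_4bit.to(torch.float16)  # dtype cast on 4-bit weights
except ValueError as err:
    print(err)  # raises, as the test asserts with assertRaises(ValueError)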