
Commit 8bf12e8

committed: add a test
1 parent 5801679 commit 8bf12e8

File tree

1 file changed: +14 additions, −0 deletions


tests/quantization/bnb/test_4bit.py

Lines changed: 14 additions & 0 deletions
@@ -752,6 +752,20 @@ def test_lora_loading(self):
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
 
+    def test_loading_lora_with_incorrect_dtype_raises_error(self):
+        self.tearDown()
+        model_dtype = torch.bfloat16
+        # https://huggingface.co/eramth/flux-4bit/blob/main/transformer/config.json#L23
+        actual_dtype = torch.float16
+        self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.bfloat16)
+        self.pipeline_4bit.enable_model_cpu_offload()
+        with self.assertRaises(ValueError) as err_context:
+            self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+        assert (
+            f"Model is in {model_dtype} dtype while the current module weight will be dequantized to {actual_dtype} dtype."
+            in str(err_context.exception)
+        )
+
 
 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):
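Note: the new test exercises a dtype guard in the LoRA-loading path. The pipeline is created with torch.bfloat16, but the eramth/flux-4bit transformer config stores torch.float16, so dequantizing the 4-bit module weights would yield a dtype that differs from the model dtype, and load_lora_weights is expected to raise a ValueError carrying the message asserted above. Below is a minimal sketch of that kind of guard, for illustration only: check_dequantized_dtype is a hypothetical name, not the actual diffusers helper, and the real check may be structured differently.

import torch


def check_dequantized_dtype(model_dtype: torch.dtype, weight_dtype: torch.dtype) -> None:
    # Hypothetical guard, sketched for illustration: raise when the dequantized
    # module weight dtype would not match the dtype the model was loaded in.
    if model_dtype != weight_dtype:
        raise ValueError(
            f"Model is in {model_dtype} dtype while the current module weight will "
            f"be dequantized to {weight_dtype} dtype."
        )


# The scenario the test sets up: model loaded in bfloat16, 4-bit weights
# dequantize to float16, so the guard trips.
check_dequantized_dtype(torch.bfloat16, torch.float16)  # raises ValueError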
