diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index bd5584296a0f..acc6d30b793a 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -526,7 +526,7 @@ def test_moving_to_cpu_throws_warning(self):
         reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
         strict=True,
     )
-    def test_pipeline_device_placement_works_with_nf4(self):
+    def test_pipeline_cuda_placement_works_with_nf4(self):
         transformer_nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
@@ -560,7 +560,7 @@ def test_pipeline_device_placement_works_with_nf4(self):
         ).to(torch_device)
 
         # Check if inference works.
-        _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
+        _ = pipeline_4bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
 
         del pipeline_4bit
 
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index f9870977999a..7abb907ff9cf 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -492,7 +492,7 @@ def test_generate_quality_dequantize(self):
         self.assertTrue(max_diff < 1e-2)
 
         # 8bit models cannot be offloaded to CPU.
-        self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda")
+        self.assertTrue(self.pipeline_8bit.transformer.device.type == torch_device)
         # calling it again shouldn't be a problem
         _ = self.pipeline_8bit(
             prompt=self.prompt,
@@ -534,7 +534,7 @@ def test_pipeline_cuda_placement_works_with_mixed_int8(self):
         ).to(device)
 
         # Check if inference works.
-        _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
+        _ = pipeline_8bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
 
         del pipeline_8bit
 
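The common thread in both files is replacing hardcoded values (the literal `"cuda"` in the device assertion and the literal `"table"` prompt) with the test-wide `torch_device` and `self.prompt`, so the checks remain valid on whichever accelerator the suite runs on. Below is a minimal sketch of that device-agnostic assertion pattern, assuming `torch_device` is the accelerator string resolved by `diffusers.utils.testing_utils`; the helper name `assert_on_accelerator` is hypothetical and only illustrates comparing against the resolved device instead of a hardcoded `"cuda"`.

```python
# Minimal sketch of the device-agnostic check the diff moves toward.
# Assumption: `torch_device` resolves to the accelerator available on the test
# machine (e.g. "cuda", "xpu", or "cpu"); `assert_on_accelerator` is a
# hypothetical helper written only for illustration.
from diffusers.utils.testing_utils import torch_device


def assert_on_accelerator(model):
    # `model` is expected to be a diffusers model (e.g. a pipeline's
    # `transformer`), which exposes a `.device` property. Comparing against the
    # resolved device type rather than a literal "cuda" keeps the check valid
    # on CUDA, XPU, or other backends.
    assert model.device.type == torch_device, (
        f"expected model on {torch_device!r}, found {model.device.type!r}"
    )
```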