File tree Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -526,7 +526,7 @@ def test_moving_to_cpu_throws_warning(self):
526526 reason = "Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release." ,
527527 strict = True ,
528528 )
529- def test_pipeline_device_placement_works_with_nf4 (self ):
529+ def test_pipeline_cuda_placement_works_with_nf4 (self ):
530530 transformer_nf4_config = BitsAndBytesConfig (
531531 load_in_4bit = True ,
532532 bnb_4bit_quant_type = "nf4" ,
@@ -560,7 +560,7 @@ def test_pipeline_device_placement_works_with_nf4(self):
560560 ).to (torch_device )
561561
562562 # Check if inference works.
563- _ = pipeline_4bit ("table" , max_sequence_length = 20 , num_inference_steps = 2 )
563+ _ = pipeline_4bit (self . prompt , max_sequence_length = 20 , num_inference_steps = 2 )
564564
565565 del pipeline_4bit
566566
Original file line number Diff line number Diff line change @@ -492,7 +492,7 @@ def test_generate_quality_dequantize(self):
492492 self .assertTrue (max_diff < 1e-2 )
493493
494494 # 8bit models cannot be offloaded to CPU.
495- self .assertTrue (self .pipeline_8bit .transformer .device .type == "cuda" )
495+ self .assertTrue (self .pipeline_8bit .transformer .device .type == torch_device )
496496 # calling it again shouldn't be a problem
497497 _ = self .pipeline_8bit (
498498 prompt = self .prompt ,
@@ -534,7 +534,7 @@ def test_pipeline_cuda_placement_works_with_mixed_int8(self):
534534 ).to (device )
535535
536536 # Check if inference works.
537- _ = pipeline_8bit ("table" , max_sequence_length = 20 , num_inference_steps = 2 )
537+ _ = pipeline_8bit (self . prompt , max_sequence_length = 20 , num_inference_steps = 2 )
538538
539539 del pipeline_8bit
540540
You can’t perform that action at this time.
0 commit comments