@@ -447,21 +447,19 @@ def test_wrong_config(self):
447447            self .get_dummy_components (TorchAoConfig ("int42" ))
448448
449449
450- # This class is not to be run as a test by itself. See the tests that follow this class  
450+ # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners  
451451@require_torch  
452452@require_torch_gpu  
453453@require_torchao_version_greater_or_equal ("0.7.0" ) 
454454class  TorchAoSerializationTest (unittest .TestCase ):
455455    model_name  =  "hf-internal-testing/tiny-flux-pipe" 
456-     quant_method , quant_method_kwargs  =  None , None 
457-     device  =  "cuda" 
458456
459457    def  tearDown (self ):
460458        gc .collect ()
461459        torch .cuda .empty_cache ()
462460
463-     def  get_dummy_model (self , device = None ):
464-         quantization_config  =  TorchAoConfig (self . quant_method , ** self . quant_method_kwargs )
461+     def  get_dummy_model (self , quant_method ,  quant_method_kwargs ,  device = None ):
462+         quantization_config  =  TorchAoConfig (quant_method , ** quant_method_kwargs )
465463        quantized_model  =  FluxTransformer2DModel .from_pretrained (
466464            self .model_name ,
467465            subfolder = "transformer" ,
@@ -497,15 +495,15 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
497495            "timestep" : timestep ,
498496        }
499497
500-     def  test_original_model_expected_slice (self ):
501-         quantized_model  =  self .get_dummy_model (torch_device )
498+     def  _test_original_model_expected_slice (self ,  quant_method ,  quant_method_kwargs ,  expected_slice ):
499+         quantized_model  =  self .get_dummy_model (quant_method ,  quant_method_kwargs ,  torch_device )
502500        inputs  =  self .get_dummy_tensor_inputs (torch_device )
503501        output  =  quantized_model (** inputs )[0 ]
504502        output_slice  =  output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
505-         self .assertTrue (np .allclose (output_slice , self . expected_slice , atol = 1e-3 , rtol = 1e-3 ))
503+         self .assertTrue (np .allclose (output_slice , expected_slice , atol = 1e-3 , rtol = 1e-3 ))
506504
507-     def  check_serialization_expected_slice (self , expected_slice ):
508-         quantized_model  =  self .get_dummy_model (self . device )
505+     def  _check_serialization_expected_slice (self , quant_method ,  quant_method_kwargs ,  expected_slice ,  device ):
506+         quantized_model  =  self .get_dummy_model (quant_method ,  quant_method_kwargs ,  device )
509507
510508        with  tempfile .TemporaryDirectory () as  tmp_dir :
511509            quantized_model .save_pretrained (tmp_dir , safe_serialization = False )
@@ -524,36 +522,33 @@ def check_serialization_expected_slice(self, expected_slice):
524522        )
525523        self .assertTrue (np .allclose (output_slice , expected_slice , atol = 1e-3 , rtol = 1e-3 ))
526524
527-     def  test_serialization_expected_slice (self ):
528-         self .check_serialization_expected_slice (self .serialized_expected_slice )
529- 
530- 
531- class  TorchAoSerializationINTA8W8Test (TorchAoSerializationTest ):
532-     quant_method , quant_method_kwargs  =  "int8_dynamic_activation_int8_weight" , {}
533-     expected_slice  =  np .array ([0.3633 , - 0.1357 , - 0.0188 , - 0.249 , - 0.4688 , 0.5078 , - 0.1289 , - 0.6914 , 0.4551 ])
534-     serialized_expected_slice  =  expected_slice 
535-     device  =  "cuda" 
536- 
537- 
538- class  TorchAoSerializationINTA16W8Test (TorchAoSerializationTest ):
539-     quant_method , quant_method_kwargs  =  "int8_weight_only" , {}
540-     expected_slice  =  np .array ([0.3613 , - 0.127 , - 0.0223 , - 0.2539 , - 0.459 , 0.4961 , - 0.1357 , - 0.6992 , 0.4551 ])
541-     serialized_expected_slice  =  expected_slice 
542-     device  =  "cuda" 
543- 
544- 
545- class  TorchAoSerializationINTA8W8CPUTest (TorchAoSerializationTest ):
546-     quant_method , quant_method_kwargs  =  "int8_dynamic_activation_int8_weight" , {}
547-     expected_slice  =  np .array ([0.3633 , - 0.1357 , - 0.0188 , - 0.249 , - 0.4688 , 0.5078 , - 0.1289 , - 0.6914 , 0.4551 ])
548-     serialized_expected_slice  =  expected_slice 
549-     device  =  "cpu" 
550- 
551- 
552- class  TorchAoSerializationINTA16W8CPUTest (TorchAoSerializationTest ):
553-     quant_method , quant_method_kwargs  =  "int8_weight_only" , {}
554-     expected_slice  =  np .array ([0.3613 , - 0.127 , - 0.0223 , - 0.2539 , - 0.459 , 0.4961 , - 0.1357 , - 0.6992 , 0.4551 ])
555-     serialized_expected_slice  =  expected_slice 
556-     device  =  "cpu" 
525+     def  test_int_a8w8_cuda (self ):
526+         quant_method , quant_method_kwargs  =  "int8_dynamic_activation_int8_weight" , {}
527+         expected_slice  =  np .array ([0.3633 , - 0.1357 , - 0.0188 , - 0.249 , - 0.4688 , 0.5078 , - 0.1289 , - 0.6914 , 0.4551 ])
528+         device  =  "cuda" 
529+         self ._test_original_model_expected_slice (quant_method , quant_method_kwargs , expected_slice )
530+         self ._check_serialization_expected_slice (quant_method , quant_method_kwargs , expected_slice , device )
531+ 
532+     def  test_int_a16w8_cuda (self ):
533+         quant_method , quant_method_kwargs  =  "int8_weight_only" , {}
534+         expected_slice  =  np .array ([0.3613 , - 0.127 , - 0.0223 , - 0.2539 , - 0.459 , 0.4961 , - 0.1357 , - 0.6992 , 0.4551 ])
535+         device  =  "cuda" 
536+         self ._test_original_model_expected_slice (quant_method , quant_method_kwargs , expected_slice )
537+         self ._check_serialization_expected_slice (quant_method , quant_method_kwargs , expected_slice , device )
538+ 
539+     def  test_int_a8w8_cpu (self ):
540+         quant_method , quant_method_kwargs  =  "int8_dynamic_activation_int8_weight" , {}
541+         expected_slice  =  np .array ([0.3633 , - 0.1357 , - 0.0188 , - 0.249 , - 0.4688 , 0.5078 , - 0.1289 , - 0.6914 , 0.4551 ])
542+         device  =  "cpu" 
543+         self ._test_original_model_expected_slice (quant_method , quant_method_kwargs , expected_slice )
544+         self ._check_serialization_expected_slice (quant_method , quant_method_kwargs , expected_slice , device )
545+ 
546+     def  test_int_a16w8_cpu (self ):
547+         quant_method , quant_method_kwargs  =  "int8_weight_only" , {}
548+         expected_slice  =  np .array ([0.3613 , - 0.127 , - 0.0223 , - 0.2539 , - 0.459 , 0.4961 , - 0.1357 , - 0.6992 , 0.4551 ])
549+         device  =  "cpu" 
550+         self ._test_original_model_expected_slice (quant_method , quant_method_kwargs , expected_slice )
551+         self ._check_serialization_expected_slice (quant_method , quant_method_kwargs , expected_slice , device )
557552
558553
559554# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners 