@@ -447,21 +447,19 @@ def test_wrong_config(self):
             self.get_dummy_components(TorchAoConfig("int42"))
 
 
-# This class is not to be run as a test by itself. See the tests that follow this class
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
-    quant_method, quant_method_kwargs = None, None
-    device = "cuda"
 
     def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_dummy_model(self, device=None):
-        quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
+    def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
+        quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
         quantized_model = FluxTransformer2DModel.from_pretrained(
             self.model_name,
             subfolder="transformer",
@@ -497,15 +495,15 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
497495 "timestep" : timestep ,
498496 }
499497
500- def test_original_model_expected_slice (self ):
501- quantized_model = self .get_dummy_model (torch_device )
498+ def _test_original_model_expected_slice (self , quant_method , quant_method_kwargs , expected_slice ):
499+ quantized_model = self .get_dummy_model (quant_method , quant_method_kwargs , torch_device )
502500 inputs = self .get_dummy_tensor_inputs (torch_device )
503501 output = quantized_model (** inputs )[0 ]
504502 output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
505- self .assertTrue (np .allclose (output_slice , self . expected_slice , atol = 1e-3 , rtol = 1e-3 ))
503+ self .assertTrue (np .allclose (output_slice , expected_slice , atol = 1e-3 , rtol = 1e-3 ))
506504
507- def check_serialization_expected_slice (self , expected_slice ):
508- quantized_model = self .get_dummy_model (self . device )
505+ def _check_serialization_expected_slice (self , quant_method , quant_method_kwargs , expected_slice , device ):
506+ quantized_model = self .get_dummy_model (quant_method , quant_method_kwargs , device )
509507
510508 with tempfile .TemporaryDirectory () as tmp_dir :
511509 quantized_model .save_pretrained (tmp_dir , safe_serialization = False )
@@ -524,36 +522,33 @@ def check_serialization_expected_slice(self, expected_slice):
         )
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
-    def test_serialization_expected_slice(self):
-        self.check_serialization_expected_slice(self.serialized_expected_slice)
-
-
-class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cuda"
-
-
-class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cuda"
-
-
-class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cpu"
-
-
-class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cpu"
+    def test_int_a8w8_cuda(self):
+        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a16w8_cuda(self):
+        quant_method, quant_method_kwargs = "int8_weight_only", {}
+        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a8w8_cpu(self):
+        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cpu"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a16w8_cpu(self):
+        quant_method, quant_method_kwargs = "int8_weight_only", {}
+        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+        device = "cpu"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
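
For reference, below is a minimal sketch of the quantize/save/reload round trip that the new `_check_serialization_expected_slice` helper exercises, outside the test harness. The model id, the quantization method name, and `safe_serialization=False` are taken from the diff above; the `torch_dtype` and reload arguments are assumptions and may need adjusting for your setup.

import tempfile

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Quantize the Flux transformer with torchao int8 weight-only quantization,
# one of the two methods covered by the new tests.
quantization_config = TorchAoConfig("int8_weight_only")
model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,  # assumed dtype, not dictated by the diff
)

with tempfile.TemporaryDirectory() as tmp_dir:
    # torchao quantized weights are tensor subclasses and cannot be stored as
    # safetensors, hence safe_serialization=False, matching the tests.
    model.save_pretrained(tmp_dir, safe_serialization=False)
    reloaded = FluxTransformer2DModel.from_pretrained(tmp_dir, use_safetensors=False)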