@@ -447,21 +447,19 @@ def test_wrong_config(self):
447
447
self .get_dummy_components (TorchAoConfig ("int42" ))
448
448
449
449
450
- # This class is not to be run as a test by itself. See the tests that follow this class
450
+ # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
451
451
@require_torch
452
452
@require_torch_gpu
453
453
@require_torchao_version_greater_or_equal ("0.7.0" )
454
454
class TorchAoSerializationTest (unittest .TestCase ):
455
455
model_name = "hf-internal-testing/tiny-flux-pipe"
456
- quant_method , quant_method_kwargs = None , None
457
- device = "cuda"
458
456
459
457
def tearDown(self):
    """Free memory after each test: run the garbage collector, then empty the CUDA cache."""
    # Order matters: collect unreachable Python objects first, then ask
    # torch to empty its CUDA cache.
    gc.collect()
    torch.cuda.empty_cache()
462
460
463
- def get_dummy_model (self , device = None ):
464
- quantization_config = TorchAoConfig (self . quant_method , ** self . quant_method_kwargs )
461
+ def get_dummy_model (self , quant_method , quant_method_kwargs , device = None ):
462
+ quantization_config = TorchAoConfig (quant_method , ** quant_method_kwargs )
465
463
quantized_model = FluxTransformer2DModel .from_pretrained (
466
464
self .model_name ,
467
465
subfolder = "transformer" ,
@@ -497,15 +495,15 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
497
495
"timestep" : timestep ,
498
496
}
499
497
500
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
    """Compare a freshly quantized (never serialized) model against a reference slice.

    The model is built with ``self.get_dummy_model`` on ``torch_device`` and run once
    on the inputs from ``self.get_dummy_tensor_inputs``. The last nine values of the
    flattened first output are then compared with ``expected_slice``.

    Args:
        quant_method: Quantization type name forwarded to ``get_dummy_model``.
        quant_method_kwargs: Extra keyword arguments forwarded to ``get_dummy_model``.
        expected_slice: Reference values for the last nine flattened outputs.

    Raises:
        AssertionError: If the slice differs from the reference beyond
            ``atol=1e-3`` / ``rtol=1e-3``.
    """
    quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
    inputs = self.get_dummy_tensor_inputs(torch_device)
    output = quantized_model(**inputs)[0]
    # Cast to float32 on the CPU so the slice can be handed to NumPy.
    output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
    # Include both slices in the failure message; a bare assertTrue would only
    # report "False is not true".
    self.assertTrue(
        np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3),
        msg=f"Output slice {output_slice} does not match expected slice {expected_slice} for {quant_method}",
    )
506
504
507
- def check_serialization_expected_slice (self , expected_slice ):
508
- quantized_model = self .get_dummy_model (self . device )
505
+ def _check_serialization_expected_slice (self , quant_method , quant_method_kwargs , expected_slice , device ):
506
+ quantized_model = self .get_dummy_model (quant_method , quant_method_kwargs , device )
509
507
510
508
with tempfile .TemporaryDirectory () as tmp_dir :
511
509
quantized_model .save_pretrained (tmp_dir , safe_serialization = False )
@@ -524,36 +522,33 @@ def check_serialization_expected_slice(self, expected_slice):
524
522
)
525
523
self .assertTrue (np .allclose (output_slice , expected_slice , atol = 1e-3 , rtol = 1e-3 ))
526
524
527
- def test_serialization_expected_slice (self ):
528
- self .check_serialization_expected_slice (self .serialized_expected_slice )
529
-
530
-
531
- class TorchAoSerializationINTA8W8Test (TorchAoSerializationTest ):
532
- quant_method , quant_method_kwargs = "int8_dynamic_activation_int8_weight" , {}
533
- expected_slice = np .array ([0.3633 , - 0.1357 , - 0.0188 , - 0.249 , - 0.4688 , 0.5078 , - 0.1289 , - 0.6914 , 0.4551 ])
534
- serialized_expected_slice = expected_slice
535
- device = "cuda"
536
-
537
-
538
- class TorchAoSerializationINTA16W8Test (TorchAoSerializationTest ):
539
- quant_method , quant_method_kwargs = "int8_weight_only" , {}
540
- expected_slice = np .array ([0.3613 , - 0.127 , - 0.0223 , - 0.2539 , - 0.459 , 0.4961 , - 0.1357 , - 0.6992 , 0.4551 ])
541
- serialized_expected_slice = expected_slice
542
- device = "cuda"
543
-
544
-
545
- class TorchAoSerializationINTA8W8CPUTest (TorchAoSerializationTest ):
546
- quant_method , quant_method_kwargs = "int8_dynamic_activation_int8_weight" , {}
547
- expected_slice = np .array ([0.3633 , - 0.1357 , - 0.0188 , - 0.249 , - 0.4688 , 0.5078 , - 0.1289 , - 0.6914 , 0.4551 ])
548
- serialized_expected_slice = expected_slice
549
- device = "cpu"
550
-
551
-
552
- class TorchAoSerializationINTA16W8CPUTest (TorchAoSerializationTest ):
553
- quant_method , quant_method_kwargs = "int8_weight_only" , {}
554
- expected_slice = np .array ([0.3613 , - 0.127 , - 0.0223 , - 0.2539 , - 0.459 , 0.4961 , - 0.1357 , - 0.6992 , 0.4551 ])
555
- serialized_expected_slice = expected_slice
556
- device = "cpu"
525
def test_int_a8w8_cuda(self):
    """int8 dynamic-activation / int8-weight: original-model check, then serialization check on "cuda"."""
    method = "int8_dynamic_activation_int8_weight"
    method_kwargs = {}
    # The same reference slice is used for both checks.
    reference = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
    self._test_original_model_expected_slice(method, method_kwargs, reference)
    self._check_serialization_expected_slice(method, method_kwargs, reference, "cuda")
531
+
532
def test_int_a16w8_cuda(self):
    """int8 weight-only: original-model check, then serialization check on "cuda"."""
    method = "int8_weight_only"
    method_kwargs = {}
    # The same reference slice is used for both checks.
    reference = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
    self._test_original_model_expected_slice(method, method_kwargs, reference)
    self._check_serialization_expected_slice(method, method_kwargs, reference, "cuda")
538
+
539
def test_int_a8w8_cpu(self):
    """int8 dynamic-activation / int8-weight: original-model check, then serialization check on "cpu"."""
    method = "int8_dynamic_activation_int8_weight"
    method_kwargs = {}
    # The same reference slice is used for both checks.
    reference = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
    # NOTE(review): this helper takes no device argument and builds the model on
    # `torch_device`; only the serialization check below receives "cpu".
    self._test_original_model_expected_slice(method, method_kwargs, reference)
    self._check_serialization_expected_slice(method, method_kwargs, reference, "cpu")
545
+
546
def test_int_a16w8_cpu(self):
    """int8 weight-only: original-model check, then serialization check on "cpu"."""
    method = "int8_weight_only"
    method_kwargs = {}
    # The same reference slice is used for both checks.
    reference = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
    # NOTE(review): this helper takes no device argument and builds the model on
    # `torch_device`; only the serialization check below receives "cpu".
    self._test_original_model_expected_slice(method, method_kwargs, reference)
    self._check_serialization_expected_slice(method, method_kwargs, reference, "cpu")
557
552
558
553
559
554
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
0 commit comments