Commit 02c777c

[tests] Refactor TorchAO serialization fast tests (#10271)
refactor
1 parent 6a970a4 commit 02c777c
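
The refactor collapses the four per-configuration TorchAoSerialization*Test subclasses into parameterized helpers (_test_original_model_expected_slice, _check_serialization_expected_slice) plus four small test methods on TorchAoSerializationTest. Below is a minimal sketch of that pattern in plain unittest; the class, helper, and value names are hypothetical and only illustrate the structure, not the diffusers code.

# Hypothetical sketch of the subclass-to-helper refactor pattern: one test class,
# a helper that takes the configuration as arguments, and tiny per-config tests.
import unittest


class SerializationRoundtripTests(unittest.TestCase):
    def _check_roundtrip(self, quant_method, quant_method_kwargs, device, expected):
        # Stand-in for: build a model quantized with `quant_method` and
        # `quant_method_kwargs`, save it, reload it on `device`, and compare
        # an output slice to `expected`.
        produced = expected  # placeholder computation
        self.assertAlmostEqual(produced, expected, places=3)

    def test_int8_weight_only_cpu(self):
        self._check_roundtrip("int8_weight_only", {}, "cpu", expected=1.0)

    def test_int8_dynamic_activation_int8_weight_cuda(self):
        self._check_roundtrip("int8_dynamic_activation_int8_weight", {}, "cuda", expected=1.0)


if __name__ == "__main__":
    unittest.main()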

File tree

1 file changed: +35 -40 lines

tests/quantization/torchao/test_torchao.py

Lines changed: 35 additions & 40 deletions
@@ -447,21 +447,19 @@ def test_wrong_config(self):
             self.get_dummy_components(TorchAoConfig("int42"))
 
 
-# This class is not to be run as a test by itself. See the tests that follow this class
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
-    quant_method, quant_method_kwargs = None, None
-    device = "cuda"
 
     def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_dummy_model(self, device=None):
-        quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
+    def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
+        quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
         quantized_model = FluxTransformer2DModel.from_pretrained(
             self.model_name,
             subfolder="transformer",
@@ -497,15 +495,15 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
             "timestep": timestep,
         }
 
-    def test_original_model_expected_slice(self):
-        quantized_model = self.get_dummy_model(torch_device)
+    def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
+        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
         inputs = self.get_dummy_tensor_inputs(torch_device)
         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-        self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3))
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
-    def check_serialization_expected_slice(self, expected_slice):
-        quantized_model = self.get_dummy_model(self.device)
+    def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
+        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
@@ -524,36 +522,33 @@ def check_serialization_expected_slice(self, expected_slice):
         )
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
-    def test_serialization_expected_slice(self):
-        self.check_serialization_expected_slice(self.serialized_expected_slice)
-
-
-class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cuda"
-
-
-class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cuda"
-
-
-class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cpu"
-
-
-class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
-    quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-    serialized_expected_slice = expected_slice
-    device = "cpu"
+    def test_int_a8w8_cuda(self):
+        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a16w8_cuda(self):
+        quant_method, quant_method_kwargs = "int8_weight_only", {}
+        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a8w8_cpu(self):
+        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cpu"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+    def test_int_a16w8_cpu(self):
+        quant_method, quant_method_kwargs = "int8_weight_only", {}
+        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+        device = "cpu"
+        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
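
The middle of _check_serialization_expected_slice (between save_pretrained and the slice comparison) falls outside the hunks shown above. A rough standalone sketch of the save/reload roundtrip it exercises, assuming the diffusers APIs visible in the diff (TorchAoConfig, FluxTransformer2DModel); the dtype and reload kwargs are assumptions, not taken from the hidden test code.

# Sketch only: reproduces the shape of the serialization roundtrip, not the exact test code.
import tempfile

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8_weight_only")
model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,  # assumed dtype; not visible in the hunks above
)

with tempfile.TemporaryDirectory() as tmp_dir:
    # The test saves with safe_serialization=False, as shown in the diff.
    model.save_pretrained(tmp_dir, safe_serialization=False)
    # Reload kwargs are assumptions; use_safetensors=False matches the
    # non-safetensors checkpoint written above.
    reloaded = FluxTransformer2DModel.from_pretrained(tmp_dir, use_safetensors=False, torch_dtype=torch.bfloat16)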
