|
36 | 36 | nightly, |
37 | 37 | require_torch, |
38 | 38 | require_torch_gpu, |
39 | | - require_torchao_version_greater, |
| 39 | + require_torchao_version_greater_or_equal, |
40 | 40 | slow, |
41 | 41 | torch_device, |
42 | 42 | ) |
@@ -74,13 +74,12 @@ def forward(self, input, *args, **kwargs): |
74 | 74 |
|
75 | 75 | if is_torchao_available(): |
76 | 76 | from torchao.dtypes import AffineQuantizedTensor |
77 | | - from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType |
78 | 77 | from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor |
79 | 78 |
|
80 | 79 |
|
81 | 80 | @require_torch |
82 | 81 | @require_torch_gpu |
83 | | -@require_torchao_version_greater("0.6.0") |
| 82 | +@require_torchao_version_greater_or_equal("0.7.0") |
84 | 83 | class TorchAoConfigTest(unittest.TestCase): |
85 | 84 | def test_to_dict(self): |
86 | 85 | """ |
@@ -125,7 +124,7 @@ def test_repr(self): |
125 | 124 | # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners |
126 | 125 | @require_torch |
127 | 126 | @require_torch_gpu |
128 | | -@require_torchao_version_greater("0.6.0") |
| 127 | +@require_torchao_version_greater_or_equal("0.7.0") |
129 | 128 | class TorchAoTest(unittest.TestCase): |
130 | 129 | def tearDown(self): |
131 | 130 | gc.collect() |
@@ -276,7 +275,6 @@ def test_int4wo_quant_bfloat16_conversion(self): |
276 | 275 | self.assertTrue(isinstance(weight, AffineQuantizedTensor)) |
277 | 276 | self.assertEqual(weight.quant_min, 0) |
278 | 277 | self.assertEqual(weight.quant_max, 15) |
279 | | - self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) |
280 | 278 |
|
281 | 279 | def test_offload(self): |
282 | 280 | """ |
@@ -432,7 +430,7 @@ def test_wrong_config(self): |
432 | 430 | # This class is not to be run as a test by itself. See the tests that follow this class |
433 | 431 | @require_torch |
434 | 432 | @require_torch_gpu |
435 | | -@require_torchao_version_greater("0.6.0") |
| 433 | +@require_torchao_version_greater_or_equal("0.7.0") |
436 | 434 | class TorchAoSerializationTest(unittest.TestCase): |
437 | 435 | model_name = "hf-internal-testing/tiny-flux-pipe" |
438 | 436 | quant_method, quant_method_kwargs = None, None |
@@ -541,7 +539,7 @@ class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest): |
541 | 539 | # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners |
542 | 540 | @require_torch |
543 | 541 | @require_torch_gpu |
544 | | -@require_torchao_version_greater("0.6.0") |
| 542 | +@require_torchao_version_greater_or_equal("0.7.0") |
545 | 543 | @slow |
546 | 544 | @nightly |
547 | 545 | class SlowTorchAoTests(unittest.TestCase): |
|
0 commit comments