Skip to content

Commit 0f812cd

Browse files
committed
bump min torchao version to 0.7.0
1 parent 63cdf9c commit 0f812cd

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

src/diffusers/utils/testing_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,11 +490,11 @@ def decorator(test_case):
490490
return decorator
491491

492492

493-
def require_torchao_version_greater(torchao_version):
493+
def require_torchao_version_greater_or_equal(torchao_version):
494494
def decorator(test_case):
495495
correct_torchao_version = is_torchao_available() and version.parse(
496496
version.parse(importlib.metadata.version("torchao")).base_version
497-
) > version.parse(torchao_version)
497+
) >= version.parse(torchao_version)
498498
return unittest.skipUnless(
499499
correct_torchao_version, f"Test requires torchao with version greater than or equal to {torchao_version}."
500500
)(test_case)

tests/quantization/torchao/test_torchao.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
nightly,
3737
require_torch,
3838
require_torch_gpu,
39-
require_torchao_version_greater,
39+
require_torchao_version_greater_or_equal,
4040
slow,
4141
torch_device,
4242
)
@@ -74,13 +74,12 @@ def forward(self, input, *args, **kwargs):
7474

7575
if is_torchao_available():
7676
from torchao.dtypes import AffineQuantizedTensor
77-
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
7877
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
7978

8079

8180
@require_torch
8281
@require_torch_gpu
83-
@require_torchao_version_greater("0.6.0")
82+
@require_torchao_version_greater_or_equal("0.7.0")
8483
class TorchAoConfigTest(unittest.TestCase):
8584
def test_to_dict(self):
8685
"""
@@ -125,7 +124,7 @@ def test_repr(self):
125124
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
126125
@require_torch
127126
@require_torch_gpu
128-
@require_torchao_version_greater("0.6.0")
127+
@require_torchao_version_greater_or_equal("0.7.0")
129128
class TorchAoTest(unittest.TestCase):
130129
def tearDown(self):
131130
gc.collect()
@@ -276,7 +275,6 @@ def test_int4wo_quant_bfloat16_conversion(self):
276275
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
277276
self.assertEqual(weight.quant_min, 0)
278277
self.assertEqual(weight.quant_max, 15)
279-
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
280278

281279
def test_offload(self):
282280
"""
@@ -432,7 +430,7 @@ def test_wrong_config(self):
432430
# This class is not to be run as a test by itself. See the tests that follow this class
433431
@require_torch
434432
@require_torch_gpu
435-
@require_torchao_version_greater("0.6.0")
433+
@require_torchao_version_greater_or_equal("0.7.0")
436434
class TorchAoSerializationTest(unittest.TestCase):
437435
model_name = "hf-internal-testing/tiny-flux-pipe"
438436
quant_method, quant_method_kwargs = None, None
@@ -541,7 +539,7 @@ class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
541539
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
542540
@require_torch
543541
@require_torch_gpu
544-
@require_torchao_version_greater("0.6.0")
542+
@require_torchao_version_greater_or_equal("0.7.0")
545543
@slow
546544
@nightly
547545
class SlowTorchAoTests(unittest.TestCase):

0 commit comments

Comments
 (0)