Skip to content

Commit 03d342d

Browse files
committed
Skip int4 test if onnx version is > 1.18
1 parent ed78802 commit 03d342d

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/_test_utils/import_helper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,16 @@ def skip_if_no_megatron(apex_or_te_required: bool = False, mamba_required: bool
7777
pytest.skip("Mamba required for Megatron test", allow_module_level=True)
7878

7979

80-
def skip_if_onnx_version_below_1_19():
80+
def skip_if_onnx_version_above_1_18():
8181
package_name = "onnx"
82-
required_version = "1.19.0"
82+
required_version = "1.18.0"
8383

8484
try:
8585
installed_version = importlib.metadata.version(package_name)
8686
except importlib.metadata.PackageNotFoundError:
8787
pytest.skip(f"{package_name} is not installed")
8888

89-
if version.parse(installed_version) < version.parse(required_version):
89+
if version.parse(installed_version) > version.parse(required_version):
9090
pytest.skip(
9191
f"{package_name} version {installed_version} is less than required {required_version}"
9292
)

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from functools import partial
2121

2222
import torch
23-
from _test_utils.import_helper import skip_if_no_libcudnn, skip_if_onnx_version_below_1_19
23+
from _test_utils.import_helper import skip_if_no_libcudnn, skip_if_onnx_version_above_1_18
2424
from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init
2525
from _test_utils.torch_quantization.quantize_common import get_awq_config
2626

@@ -40,7 +40,7 @@
4040

4141

4242
def test_int4_awq(tmp_path):
43-
skip_if_onnx_version_below_1_19()
43+
skip_if_onnx_version_above_1_18()
4444

4545
def _forward_loop(model, dataloader):
4646
"""Forward loop for calibration."""
@@ -116,7 +116,7 @@ def _forward_loop(model, dataloader):
116116

117117

118118
def test_int4_awq_cuda(tmp_path):
119-
skip_if_onnx_version_below_1_19()
119+
skip_if_onnx_version_above_1_18()
120120
skip_if_no_libcudnn()
121121
block_size = 128
122122

0 commit comments

Comments
 (0)