Skip to content

Commit c1ea892

Browse files
committed
Skip test if onnx version < 1.19
Signed-off-by: ajrasane <[email protected]>
1 parent 3420d48 commit c1ea892

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
datasets>=2.14.5
2+
onnx==1.18.0
23
torch==2.6.0
34
transformers==4.49.0

tests/_test_utils/import_helper.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import importlib.metadata
1617
import shutil
1718

1819
import pytest
20+
from packaging import version
1921

2022

2123
def skip_if_no_tensorrt():
@@ -73,3 +75,18 @@ def skip_if_no_megatron(apex_or_te_required: bool = False, mamba_required: bool
7375

7476
if mamba_required and not has_mamba:
7577
pytest.skip("Mamba required for Megatron test", allow_module_level=True)
78+
79+
80+
def skip_if_onnx_version_below_1_19(
    package_name: str = "onnx", required_version: str = "1.19.0"
) -> None:
    """Skip the calling test unless ``package_name`` >= ``required_version`` is installed.

    Generalized helper: defaults preserve the original behavior (onnx >= 1.19.0),
    but callers may check any package/version pair.

    Args:
        package_name: Distribution name to look up via ``importlib.metadata``.
        required_version: Minimum acceptable version (PEP 440 string).

    Raises:
        Skipped: via ``pytest.skip`` when the package is missing or too old.
    """
    try:
        installed_version = importlib.metadata.version(package_name)
    except importlib.metadata.PackageNotFoundError:
        # Missing package is treated the same as "too old": skip, don't fail.
        pytest.skip(f"{package_name} is not installed")

    # packaging handles PEP 440 semantics (e.g. "1.19.0rc1" < "1.19.0"),
    # which naive string/tuple comparison would get wrong.
    if version.parse(installed_version) < version.parse(required_version):
        pytest.skip(
            f"{package_name} version {installed_version} is less than required {required_version}"
        )

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from functools import partial
2121

2222
import torch
23-
from _test_utils.import_helper import skip_if_no_libcudnn
23+
from _test_utils.import_helper import skip_if_no_libcudnn, skip_if_onnx_version_below_1_19
2424
from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init
2525
from _test_utils.torch_quantization.quantize_common import get_awq_config
2626

@@ -40,6 +40,8 @@
4040

4141

4242
def test_int4_awq(tmp_path):
43+
skip_if_onnx_version_below_1_19()
44+
4345
def _forward_loop(model, dataloader):
4446
"""Forward loop for calibration."""
4547
for data in dataloader:
@@ -114,6 +116,7 @@ def _forward_loop(model, dataloader):
114116

115117

116118
def test_int4_awq_cuda(tmp_path):
119+
skip_if_onnx_version_below_1_19()
117120
skip_if_no_libcudnn()
118121
block_size = 128
119122

0 commit comments

Comments
 (0)