examples/onnx_ptq/torch_quant_to_onnx.py (12 changes: 9 additions & 3 deletions)

@@ -83,12 +83,12 @@ def forward_loop(model):
return quantized_model


-def get_model_input_shape(model_name):
+def get_model_input_shape(model_name, batch_size):
"""Get the input shape from timm model configuration."""
model = timm.create_model(model_name, pretrained=True, num_classes=1000)
data_config = timm.data.resolve_model_data_config(model)
input_size = data_config["input_size"]
-return (1, *tuple(input_size)) # Add batch dimension
+return (batch_size, *tuple(input_size)) # Add batch dimension


def main():
@@ -119,11 +119,17 @@ def main():
default=512,
help="Number of images to use in calibration [1-512]",
)
+parser.add_argument(
+"--batch_size",
+type=int,
+default=1,
+help="Batch size for calibration and ONNX model export.",
+)

args = parser.parse_args()

# Get input shape from model config
-input_shape = get_model_input_shape(args.timm_model_name)
+input_shape = get_model_input_shape(args.timm_model_name, args.batch_size)

# Create model and move to appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
examples/windows/onnx_ptq/genai_llm/requirements.txt (1 change: 1 addition & 0 deletions)

@@ -1,3 +1,4 @@
+datasets>=2.14.5
onnx==1.18.0
⚠️ Potential issue

Fix version mismatch with PR objective (onnx 1.19).

This example pins onnx==1.18.0 while the PR upgrades repo/tooling to 1.19.0 and gates tests on >=1.19. Align to avoid feature/API skew (e.g., FP4/INT4 utilities).

-onnx==1.18.0
+onnx==1.19.0
🤖 Prompt for AI Agents
In examples/windows/onnx_ptq/genai_llm/requirements.txt around line 2, the file pins onnx==1.18.0 which mismatches the repo/test expectation of onnx>=1.19.0; update the requirement to onnx==1.19.0 (or onnx>=1.19.0 if a range is preferred) so the example aligns with the PR tooling/tests and avoids API/feature skew (e.g., FP4/INT4 utilities).

torch==2.6.0
transformers==4.49.0
modelopt/onnx/quantization/qdq_utils.py (78 changes: 65 additions & 13 deletions)

@@ -23,7 +23,6 @@
import onnx_graphsurgeon as gs
import torch
from onnx import numpy_helper
-from onnx.reference.custom_element_types import float8e4m3fn

from modelopt.onnx import utils
from modelopt.onnx.logging_config import logger
@@ -50,6 +49,7 @@
onnx_dtype_map = {
"BFloat16": onnx.TensorProto.BFLOAT16,
"Float": onnx.TensorProto.FLOAT,
"Float4": onnx.TensorProto.FLOAT4E2M1,
"Float8": onnx.TensorProto.FLOAT8E4M3FN,
"Half": onnx.TensorProto.FLOAT16,
"INT8": onnx.TensorProto.INT8,
@@ -592,7 +592,7 @@ def _convert_weight(
zp_array = zp_array.reshape(*reshape_dims)

# Convert to INT8/FP8
-if zp_array.dtype == float8e4m3fn:
+if zp_array.dtype == onnx_dtype_map["Float8"]:
scaled = np.asarray(weight_array / scale_array) + zp_array
else:
scaled = np.asarray((weight_array / scale_array).round())
@@ -607,17 +607,26 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
if torch.cuda.is_available():
array_f32_t = array_f32_t.cuda()
array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
-array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
+array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
return array_f8


def _cast_fp4(array: np.ndarray) -> np.ndarray:
"""Cast a numpy array to FLOAT4E2M1 using PyTorch."""
"""Cast a numpy array to FLOAT4E2M1 using PyTorch.

Note: The first dimension of the array must be divisible by 2
as two FP4 values are packed into a single byte.
"""
array_f32_t = torch.from_numpy(array)
+array_f32_t_shape = array_f32_t.shape
+assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
+array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
if torch.cuda.is_available():
array_f32_t = array_f32_t.cuda()
array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
+array_f4_t = array_f4_t.flatten()
+array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
return array_f4
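
As an aside (not part of the diff), here is a minimal numpy sketch of the nibble packing used above: the even-indexed 4-bit code lands in the low nibble and the odd-indexed code in the high nibble, so each pair of FP4 values occupies one byte, which is why the packed axis must have an even length.

```python
import numpy as np

# Six illustrative 4-bit codes (already FP4-encoded, one code per uint8).
codes = np.array([0x1, 0xA, 0x3, 0xC, 0x5, 0xE], dtype=np.uint8)

# Pack pairs: even-indexed codes -> low nibble, odd-indexed codes -> high nibble.
packed = codes[::2] | (codes[1::2] << 4)  # -> [0xA1, 0xC3, 0xE5]

# Unpacking the two nibbles recovers the original codes.
unpacked = np.stack([packed & 0x0F, packed >> 4], axis=1).reshape(-1)
assert np.array_equal(unpacked, codes)
```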


@@ -685,7 +694,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)

# Create and update new weight tensor
-if zp_array.dtype == float8e4m3fn:
+if zp_array.dtype == onnx_dtype_map["Float8"]:
new_weight = _create_fp8_tensor(scaled, weight_name)
logger.debug(f"Converted {weight_name} to FP8")
else:
@@ -920,6 +929,10 @@ def quantize_weights_to_int4(
assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
reshape_node_output = reshape_node.output[0]

+# Remove constant node from reshape node
+shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
+nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)

# Get the shape of the output of the reshape node
reshape_output_value_info = value_info_map.get(reshape_node_output)
if reshape_output_value_info is not None:
@@ -937,12 +950,17 @@ def quantize_weights_to_int4(
scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
scale = scale.reshape(scale_shape)
reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-# reshape_node.input = []
assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"

+# Remove unnecessary Cast node
+cast_node = reshape_child_nodes[0]
+assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+nodes_to_remove.append(cast_node.name)
+cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]

# Transpose weights and scales if present
-if reshape_child_nodes[0].op_type == "Transpose":
-transpose_node = reshape_child_nodes[0]
+if cast_child_nodes[0].op_type == "Transpose":
+transpose_node = cast_child_nodes[0]
nodes_to_remove.append(transpose_node.name)
assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
perm = None
@@ -959,7 +977,7 @@
)
matmul_node = transpose_child_nodes[0]
else:
-matmul_node = reshape_child_nodes[0]
+matmul_node = cast_child_nodes[0]
assert matmul_node.op_type in ["MatMul", "Gemm"], (
f"Expected MatMul or Gemm node for {node.name}"
)
@@ -990,6 +1008,21 @@ def quantize_weights_to_int4(
initializer_map[weight_name].CopyFrom(weights_int4_onnx)
logger.debug(f"Converted {weight_name} to INT4 precision")

+def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+return node.op_type == "Mul" and has_pqs_input
+
+# Remove unnecessay Cast after Pre-quant scale
+for node in graph.node:
+if is_pre_quant_scale_node(node):
+pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+cast_node = pqs_child_nodes[0]
+assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+node.output.clear()
+node.output.extend(cast_node.output)
+nodes_to_remove.append(cast_node.name)

Comment on lines +1011 to +1025
💡 Verification agent

🧩 Analysis chain

Pre-quant Mul detection is wrong; the naming doesn’t match inserted nodes

The scale input is named with “_awq_scale” and/or produced by a Constant/initializer, not “_pre_quant_scale”. This block won’t trigger.

Make detection structural: Mul with one initializer or Constant input.

-    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
-        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
-        return node.op_type == "Mul" and has_pqs_input
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        if node.op_type != "Mul" or len(node.input) != 2:
+            return False
+        for inp in node.input:
+            if inp in initializer_map:
+                return True
+            prod = tensor_producer_map.get(inp)
+            if prod is not None and prod.op_type == "Constant":
+                return True
+        return False

Please run the unit test that exercises removal of Cast after pre-quant Mul to confirm this now triggers.


🏁 Script executed:

#!/bin/bash

echo "Occurrences of _awq_scale:"
rg -n "_awq_scale" -g "*.py" || true

echo
echo "Occurrences of pre_quant_scale:"
rg -n "pre_quant_scale" -g "*.py" || true

echo
echo "Definition of is_pre_quant_scale_node and context:"
rg -n -C3 "def is_pre_quant_scale_node" -g "modelopt/onnx/quantization/qdq_utils.py" || true

echo
echo "initializer_map occurrences in qdq_utils.py:"
rg -n "initializer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true

echo
echo "tensor_producer_map occurrences in qdq_utils.py:"
rg -n "tensor_producer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true

Length of output: 28239


Use structural detection for pre-quant Mul in qdq_utils.py
The current is_pre_quant_scale_node looks for “_pre_quant_scale” in the input name, but the inserted scale tensors are named with “_awq_scale”, so this code never matches and the Cast isn’t removed. Replace it with a structural check—i.e. a Mul with exactly two inputs, one coming from an initializer or Constant node:

-    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
-        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
-        return node.op_type == "Mul" and has_pqs_input
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        if node.op_type != "Mul" or len(node.input) != 2:
+            return False
+        for inp in node.input:
+            if inp in initializer_map:
+                return True
+            prod = tensor_producer_map.get(inp)
+            if prod is not None and prod.op_type == "Constant":
+                return True
+        return False
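
For context, a minimal sketch (assuming a loaded `onnx.ModelProto` named `model`; the exact construction in `qdq_utils.py` may differ) of the two lookup tables the structural check relies on:

```python
import onnx

model = onnx.load("quantized_model.onnx")  # illustrative path

# Initializer name -> TensorProto, and output tensor name -> producing node.
initializer_map = {init.name: init for init in model.graph.initializer}
tensor_producer_map = {out: node for node in model.graph.node for out in node.output}
```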

# Remove transpose and reshape nodes
new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
graph.node.clear()
@@ -1004,7 +1037,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
for node in graph.node:
if node.op_type == "Cast":
# Skip Cast nodes that are part of normalization layers and outputs
if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
if "norm/Cast" in node.name and is_fp32_cast(node):
continue
for attr in node.attribute:
if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
@@ -1099,7 +1132,13 @@ def quantize_weights_to_mxfp8(
# Expand block array so that it can be broadcasted with weight
se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
+weights_e4m3 = onnx.helper.make_tensor(
+name=weight_name,
+data_type=onnx_dtype_map["Float8"],
+dims=[*scaled_weight.shape],
+vals=_cast_fp8(scaled_weight).tobytes(),
+raw=True,
+)
initializer_map[weight_name].CopyFrom(weights_e4m3)
logger.debug(f"Converted {weight_name} to MXFP8")

@@ -1181,11 +1220,24 @@ def _add_input_value_info(graph, tensor_proto):
sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"

# Create TensorProto for initializers
-w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
+w_f4_proto = onnx.helper.make_tensor(
+name=w_f4_name,
+data_type=onnx_dtype_map["Float4"],
+dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
+vals=w_f4.tobytes(),
+raw=True,
+)
Comment on lines +1223 to +1229
🛠️ Refactor suggestion

FP4 initializer dims should reflect packing along the last axis

After fixing _cast_fp4 to pack along the last dim, adjust dims accordingly.

-    w_f4_proto = onnx.helper.make_tensor(
-        name=w_f4_name,
-        data_type=onnx_dtype_map["Float4"],
-        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
-        vals=w_f4.tobytes(),
-        raw=True,
-    )
+    w_f4_proto = onnx.helper.make_tensor(
+        name=w_f4_name,
+        data_type=onnx_dtype_map["Float4"],
+        dims=[*w_f4.shape[:-1], w_f4.shape[-1] * 2],
+        vals=w_f4.tobytes(),
+        raw=True,
+    )
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/qdq_utils.py around lines 1219 to 1225, the FP4 initializer currently doubles the first dimension but FP4 packing was changed to pack along the last axis; update the dims to reflect packing along the last axis by replacing dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]] with dims=[*w_f4.shape[:-1], w_f4.shape[-1] * 2] (or equivalent list/tuple construction) so the last dimension is doubled instead of the first.

sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
sw_f32_per_tensor, sw_f32_per_tensor_name
)
-sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
+sw_f8_per_block_proto = onnx.helper.make_tensor(
+name=sw_f8_per_block_name,
+data_type=onnx_dtype_map["Float8"],
+dims=[*sw_f8_per_block.shape],
+vals=sw_f8_per_block.tobytes(),
+raw=True,
+)

# Add ValueInfo for the initializers if not present
_add_input_value_info(graph, w_f4_proto)
modelopt/torch/_deploy/utils/torch_onnx.py (4 changes: 2 additions & 2 deletions)

@@ -485,8 +485,8 @@ def get_onnx_bytes_and_metadata(
except StopIteration:
param_dtype = torch.float32
if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
-if is_mxfp8_quantized(model):
-assert weights_dtype == "fp16", "BF16 + MXFP8 mixed precision is not supported yet"
+if is_mxfp8_quantized(model) or is_int4_quantized(model):
+assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
onnx_opt_graph = convert_float_to_float16(
onnx_opt_graph,
keep_io_types=False,
setup.py (4 changes: 2 additions & 2 deletions)

@@ -47,8 +47,8 @@
"cupy-cuda12x; platform_machine != 'aarch64' and platform_system != 'Darwin'",
"ml_dtypes", # for bfloat16 conversion
"onnx-graphsurgeon",
"onnx~=1.18.0",
"onnxconverter-common",
"onnx~=1.19.0",
"onnxconverter-common~=1.16.0",
"onnxruntime~=1.22.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'",
"onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'", # noqa: E501
"onnxruntime-directml==1.20.0; platform_system == 'Windows'",
tests/_test_utils/import_helper.py (17 changes: 17 additions & 0 deletions)

@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import importlib.metadata
import shutil

import pytest
+from packaging import version


def skip_if_no_tensorrt():
@@ -73,3 +75,18 @@ def skip_if_no_megatron(apex_or_te_required: bool = False, mamba_required: bool

if mamba_required and not has_mamba:
pytest.skip("Mamba required for Megatron test", allow_module_level=True)


+def skip_if_onnx_version_above_1_18():
+package_name = "onnx"
+required_version = "1.18.0"
+
+try:
+installed_version = importlib.metadata.version(package_name)
+except importlib.metadata.PackageNotFoundError:
+pytest.skip(f"{package_name} is not installed")
+
+if version.parse(installed_version) > version.parse(required_version):
+pytest.skip(
+f"{package_name} version {installed_version} is less than required {required_version}"
+)
Comment on lines +80 to +92
⚠️ Potential issue

Skip reason text is wrong; variable name misleads; add allow_module_level for consistency.

Condition skips when ONNX > 1.18, but the message says “less than required”. Rename to reflect max supported, fix message, and pass allow_module_level=True.

 def skip_if_onnx_version_above_1_18():
     package_name = "onnx"
-    required_version = "1.18.0"
+    max_supported_version = "1.18.0"

     try:
         installed_version = importlib.metadata.version(package_name)
     except importlib.metadata.PackageNotFoundError:
-        pytest.skip(f"{package_name} is not installed")
+        pytest.skip(f"{package_name} is not installed", allow_module_level=True)

-    if version.parse(installed_version) > version.parse(required_version):
+    if version.parse(installed_version) > version.parse(max_supported_version):
         pytest.skip(
-            f"{package_name} version {installed_version} is less than required {required_version}"
+            f"{package_name} version {installed_version} > supported {max_supported_version}; expected <= {max_supported_version}",
+            allow_module_level=True,
         )
🤖 Prompt for AI Agents
In tests/_test_utils/import_helper.py around lines 80 to 92, the helper misnames the version variable, logs an incorrect skip message, and omits allow_module_level; rename required_version to max_supported_version (or similar), update the skip message to say the installed ONNX version is greater than the max supported (include installed_version and max_supported_version), and call pytest.skip(..., allow_module_level=True) when skipping due to the version being above the supported maximum.

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (5 changes: 4 additions & 1 deletion)

@@ -20,7 +20,7 @@
from functools import partial

import torch
-from _test_utils.import_helper import skip_if_no_libcudnn
+from _test_utils.import_helper import skip_if_no_libcudnn, skip_if_onnx_version_above_1_18
from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init
from _test_utils.torch_quantization.quantize_common import get_awq_config

@@ -40,6 +40,8 @@


def test_int4_awq(tmp_path):
+skip_if_onnx_version_above_1_18()

def _forward_loop(model, dataloader):
"""Forward loop for calibration."""
for data in dataloader:
@@ -114,6 +116,7 @@ def _forward_loop(model, dataloader):


def test_int4_awq_cuda(tmp_path):
+skip_if_onnx_version_above_1_18()
skip_if_no_libcudnn()
block_size = 128
