
Conversation

ajrasane (Contributor) commented Sep 5, 2025

What does this PR do?

Type of change: Version upgrade

Overview:

  • Upgraded onnx to v1.19.0
  • Upgraded onnxconverter-common to v1.16.0
  • Added batch_size to torch_quant_to_onnx example
  • Updated fp4 casting logic
  • Updated int4 weight conversion logic and postprocessing
  • Updated tests based on new logic

Testing

All unit and integration tests pass

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Summary by CodeRabbit

  • New Features

    • FP4 (Float4) support for ONNX quantization/export.
    • New CLI option --batch_size to control model input shapes and calibration loading.
    • FP16/BF16 export also triggered for INT4-quantized models.
  • Refactor

    • Simplified FP8/FP4 casting, packing, and QDQ→DQ transformations; reduced redundant casts and improved graph handling.
  • Chores

    • ONNX extras bumped (onnx ~=1.19.0, onnxconverter-common ~=1.16.0); example requirements updated to include ONNX.
  • Tests

    • Unit and GPU tests updated for FP4/FP8; added ONNX-version gating and skips.

coderabbitai bot commented Sep 5, 2025

Walkthrough

Adds batch-size plumbing to the ONNX PTQ example; introduces FP4 dtype support and refactors FP8/FP4 casting and initializer creation in ONNX QDQ utilities; expands Torch ONNX conversion trigger to include INT4; updates ONNX extras and example requirements; adds ONNX-version-based test gating and adapts unit/gpu tests and unit test data/layouts.

Changes

Cohort / File(s) Summary of Changes
ONNX PTQ example: batch size plumbing
examples/onnx_ptq/torch_quant_to_onnx.py
Added --batch_size CLI; get_model_input_shape(model_name, batch_size) returns shapes with leading batch dim; load_calibration_data(model_name, data_size, batch_size, device) uses that batch size; main flow passes batch into input shape, calibration loader, and ONNX export.
ONNX QDQ utilities: FP4 support and FP8/FP4 refactor
modelopt/onnx/quantization/qdq_utils.py
Added Float4 mapping (FLOAT4E2M1) in onnx_dtype_map. _cast_fp8 returns flat uint8 buffer (via PyTorch); _cast_fp4 packs two 4-bit values into bytes and returns flat uint8 with assertions (see the packing sketch after this table). Initializers for FP4/FP8 created with onnx.helper.make_tensor and explicit dtypes/raw data. Updated QDQ→DQ, FP4QDQ→2DQ, weight conversion, pre-quant-scale Cast detection/removal, and related dtype/shape checks.
Torch ONNX deploy utility
modelopt/torch/_deploy/utils/torch_onnx.py
In get_onnx_bytes_and_metadata, FP16/BF16 conversion trigger expanded to run for INT4-quantized models as well as MXFP8; assertion message updated to mention MXFP8/INT4 mixed precision.
Dependencies: ONNX extras
setup.py
ONNX extras updated: onnx~=1.19.0; added/pinned onnxconverter-common~=1.16.0.
Example requirements (Windows)
examples/windows/onnx_ptq/genai_llm/requirements.txt
Added onnx==1.18.0 line to the example requirements.
Test utilities: ONNX version gating helper
tests/_test_utils/import_helper.py
Added helper skip_if_onnx_version_above_1_18() that uses importlib.metadata and packaging.version to skip tests when installed ONNX > 1.18.0 (or skip if ONNX missing).
GPU tests: ONNX-version gated
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py
Imported and invoked skip_if_onnx_version_above_1_18() at test start for test_int4_awq and test_int4_awq_cuda to skip when ONNX > 1.18.0.
Unit tests: QDQ utils adjustments
tests/unit/onnx/test_qdq_utils.py
Graph construction changed: Reshape consumes a Constant for shape; added Cast between Reshape and Transpose; removed reshape-shape initializer. Test arrays for FP4/FP8 made 2D and expected outputs updated (FP4 expected as plain uint8 packed), and Cast-handling expectations relaxed/adjusted.
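
As a rough illustration of the packing scheme summarized in the qdq_utils.py row above, the sketch below packs pairs of 4-bit codes into single bytes along the first axis (low nibble first). The helper name is made up, and the real _cast_fp4 additionally maps float weights to FLOAT4E2M1 codes via NVFP4QTensor before packing; only the bit-packing step is shown here.

import numpy as np

def pack_fp4_pairs(codes: np.ndarray) -> np.ndarray:
    # `codes` holds integer values in [0, 15]; two codes are packed per byte.
    assert codes.shape[0] % 2 == 0, "first dimension must be even to pack pairs"
    flat = codes.astype(np.uint8).flatten()
    packed = flat[0::2] | (flat[1::2] << 4)  # even index -> low nibble
    return packed.reshape(codes.shape[0] // 2, *codes.shape[1:])

# Eight 4-bit codes become four packed bytes with half the first dimension.
demo = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.uint8)
assert pack_fp4_pairs(demo).shape == (2, 2)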

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant CLI as torch_quant_to_onnx.py
  participant Model as Torch Model
  participant Loader as Calibration DataLoader
  participant Export as ONNX Export

  User->>CLI: run with --batch_size N
  CLI->>CLI: get_model_input_shape(model, N)
  CLI->>Loader: load_calibration_data(model, data_size, N, device)
  Loader-->>CLI: DataLoader (batch=N)
  alt calibration enabled
    CLI->>Model: calibrate using Loader
  end
  CLI->>Export: export ONNX with input shape [N,...]
  Export-->>User: ONNX model (batched)
sequenceDiagram
  autonumber
  participant Graph as qdq_utils
  participant Weights as Weight Tensor
  participant CastRoutine as _cast_fp4/_cast_fp8
  participant Init as onnx.helper.make_tensor

  Graph->>Weights: select quantizable weights
  alt FP4 target
    Weights->>CastRoutine: _cast_fp4 (pack 2×4-bit → uint8)
    CastRoutine-->>Init: raw_data (packed uint8), dtype=Float4
    Init->>Graph: create FP4 initializer
  else FP8 target
    Weights->>CastRoutine: _cast_fp8 (flat uint8)
    CastRoutine-->>Init: raw_data (uint8), dtype=Float8
    Init->>Graph: create FP8 initializer
  end
  Graph->>Graph: detect/remove pre-quant-scale Casts
  Graph->>Graph: rewrite QDQ→DQ / FP4QDQ→2DQ nodes and value_info
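
A minimal sketch of the initializer-creation step in the diagram above, using onnx.helper.make_tensor; the tensor names, shapes, and byte contents are placeholders, and FLOAT4E2M1 is only available in recent ONNX releases such as the 1.19 line targeted by this PR.

import numpy as np
import onnx

# FP4: dims describe the logical (unpacked) element count, while the raw data
# carries the packed bytes (two FLOAT4E2M1 elements per byte).
packed_fp4 = np.arange(8, dtype=np.uint8)  # 8 bytes -> 16 FP4 elements
w_f4 = onnx.helper.make_tensor(
    name="layer0.weight_f4",
    data_type=onnx.TensorProto.FLOAT4E2M1,
    dims=[16],
    vals=packed_fp4.tobytes(),
    raw=True,
)

# FP8: one byte per element, tagged with the FLOAT8E4M3FN dtype.
fp8_bytes = np.zeros(4, dtype=np.uint8)  # stand-in for _cast_fp8 output
w_f8 = onnx.helper.make_tensor(
    name="layer1.weight_f8",
    data_type=onnx.TensorProto.FLOAT8E4M3FN,
    dims=[4],
    vals=fp8_bytes.tobytes(),
    raw=True,
)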

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 66.67%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title Check (✅ Passed): The title “Upgrade to ONNX 1.19.0” is a concise, single-sentence summary that accurately highlights the main change of upgrading the ONNX dependency and aligns with the developer’s primary objective without extraneous details.
  • Description Check (✅ Passed): The description clearly outlines the version upgrades, enumerates the added batch_size support, the FP4/FP8 and INT4 logic updates, and the test adjustments, and confirms all tests pass, making it directly relevant to the changeset.

Poem

I nibble bytes like clover leaves so bright,
Two nibbles in a bite—FP4 tucked tight.
I hop through graphs, unpick casts that sprawl,
Count batch-size footsteps down the quantized hall.
Carrots to ONNX, exported just right. 🥕🐇

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).


coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (8)
modelopt/onnx/quantization/qdq_utils.py (1)

928-930: Use op_type-based detection for Constant nodes
Replace the substring match on "Constant" in reshape_node.input with an explicit check of the producer node’s op_type to avoid relying on name patterns. For example:

- shape_constant_name = next(i for i in reshape_node.input if "Constant" in i)
+ shape_constant_name = next(
+     i for i in reshape_node.input
+     if tensor_producer_map[i].op_type == "Constant"
+ )
tests/unit/onnx/test_qdq_utils.py (4)

70-76: Redundant Cast between Reshape and Transpose (optional)

dq_output is already FLOAT; inserting weight_cast to FLOAT is a no-op unless it’s intentionally exercising cast-conversion logic. If this Cast is only for testing that path, consider a brief comment or renaming the node to indicate intent; otherwise, remove it to keep the pattern minimal (DQ → Reshape → Transpose).

Also applies to: 80-80


96-96: Avoid seeding an unused scale initializer when constant_scale=True

When constant_scale=True, the graph still includes a scale initializer up-front, which can make the “new scale initializer” assertion trivially pass. Suggest emitting it conditionally so the test precisely validates the pass rewires the scale.

Apply this diff:

-    nodes = [dq_node, reshape_constant, reshape_node, cast_node, transpose_node, matmul_node]
+    nodes = [dq_node, reshape_constant, reshape_node, cast_node, transpose_node, matmul_node]
@@
-        initializer=[weight_tensor, scale_tensor],
+        initializer=[weight_tensor] if constant_scale else [weight_tensor, scale_tensor],

Also applies to: 104-104


252-257: Name-based Cast preservation: tighten the match (nit)

if "norm/Cast" in node.name may over-match unrelated nodes. Consider anchoring (e.g., regex (^|/)layer_norm/Cast$) to avoid false positives as models grow.


315-316: FP4 tests look consistent with packed 4-bit along axis 0; add a negative/edge test

The 2D inputs and uint8 expectations align with _cast_fp4 packing the first dimension by 2. Add a failure case for odd first-dimension to lock the contract, and optionally a 4D case to exercise batching.

Example additions:

def test_cast_fp4_odd_first_dim_raises():
    with pytest.raises(AssertionError):
        _cast_fp4(np.zeros((3, 2), dtype=np.float32))

def test_cast_fp4_4d_batch():
    x = np.random.randn(2, 2, 2, 2).astype(np.float32)  # first dim even
    y = _cast_fp4(x)
    assert y.dtype == np.uint8
    assert y.shape[0] == 1 and y.shape[1:] == x.shape[1:]

Also applies to: 320-322, 325-327, 330-332, 335-337, 340-342, 348-348

examples/onnx_ptq/torch_quant_to_onnx.py (3)

86-92: Avoid heavy model instantiation just to read input size (optional)

Creating a full pretrained model here can be slow. If available in your timm version, prefer fetching the default cfg to get input_size without instantiating weights; otherwise, at least consider pretrained=False for this helper.
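
A possible lightweight variant along these lines, assuming the installed timm supports resolve_model_data_config on a model created with pretrained=False (the function name mirrors the example script, but the body is only a sketch, not the repository's implementation):

import timm

def get_model_input_shape(model_name: str, batch_size: int) -> tuple:
    # Skip the pretrained weight download; the data config only needs the architecture.
    model = timm.create_model(model_name, pretrained=False)
    data_config = timm.data.resolve_model_data_config(model)
    return (batch_size, *data_config["input_size"])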


122-127: Add basic validation for --batch_size (and align with data size) (optional)

Help text implies constraints but they’re not enforced. Guard against invalid values and optionally round calibration_data_size to full batches for determinism.

Example (place after args = parser.parse_args()):

if args.batch_size <= 0:
    raise ValueError("--batch_size must be > 0")
if args.calibration_data_size <= 0:
    raise ValueError("--calibration_data_size must be > 0")
# Optional: ensure full batches only
# args.calibration_data_size = (args.calibration_data_size // args.batch_size) * args.batch_size

141-149: DataLoader returns GPU tensors from worker processes (risk of pickling/IPC issues)

load_calibration_data constructs CUDA tensors and then uses num_workers=4. Multiprocessing with GPU tensors can be brittle and memory-heavy. Prefer keeping tensors on CPU in the loader and moving them to device in the forward loop, or set num_workers=0 for simplicity.

Suggested adjustments:

  • Keep calibration tensors on CPU in load_calibration_data, and enable pin_memory=True:
return torch.utils.data.DataLoader(
    calib_tensor, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True
)
  • Move batches to device inside forward_loop:
def forward_loop(model):
    for batch in data_loader:
        if isinstance(batch, torch.Tensor):
            batch = batch.to(device, non_blocking=True)
        model(batch)
  • Optionally set shuffle=False for reproducible calibration.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 76fb12d and d792d4f.

📒 Files selected for processing (5)
  • examples/onnx_ptq/torch_quant_to_onnx.py (2 hunks)
  • modelopt/onnx/quantization/qdq_utils.py (11 hunks)
  • modelopt/torch/_deploy/utils/torch_onnx.py (1 hunks)
  • setup.py (1 hunks)
  • tests/unit/onnx/test_qdq_utils.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/unit/onnx/test_qdq_utils.py (1)
modelopt/onnx/quantization/qdq_utils.py (1)
  • _cast_fp4 (614-626)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1)
  • NVFP4QTensor (31-295)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (11)
setup.py (1)

50-51: LGTM! Correct version upgrades for ONNX dependencies.

The ONNX version upgrades align with the PR objectives:

  • onnx~=1.19.0: Enables FP4 dtype support (FLOAT4E2M1)
  • onnxconverter-common~=1.16.0: Ensures compatibility with the new ONNX version
modelopt/onnx/quantization/qdq_utils.py (8)

49-57: LGTM! Proper addition of FP4 dtype mapping.

The addition of "Float4": onnx.TensorProto.FLOAT4E2M1 correctly extends the dtype map to support FP4 quantization, which is one of the key features in this ONNX upgrade.


604-612: Critical fix: FP8 casting now returns correct flat uint8 array.

The change from structured dtype to flat uint8 array aligns with ONNX's expected data format for FP8 tensors. This is a necessary fix for proper FP8 quantization.


693-694: LGTM! Consistent use of dtype map for FP8 checking.

The change to use onnx_dtype_map["Float8"] maintains consistency with the updated dtype handling throughout the codebase.


951-956: LGTM! Proper graph simplification by removing unnecessary Cast nodes.

The removal of Cast nodes between Reshape and Transpose operations optimizes the graph structure for INT4 quantization.


1007-1021: Good optimization: Removing redundant Cast after pre-quant scale.

The helper function is_pre_quant_scale_node correctly identifies and removes unnecessary Cast nodes following pre-quantization scale operations, improving graph efficiency.


1131-1138: LGTM! Proper FP8 tensor creation using ONNX helper.

The use of onnx.helper.make_tensor with explicit Float8 dtype and raw bytes ensures correct FP8 weight representation in the ONNX graph.


1219-1236: LGTM! Correct FP4 tensor creation with proper dimensions.

The FP4 tensor creation correctly:

  1. Uses onnx_dtype_map["Float4"] for the data type
  2. Adjusts dimensions to account for packing (2x values in first dim)
  3. Uses raw bytes for efficient storage

595-596: Float8 dtype check is backward compatible
onnx_dtype_map["Float8"] still resolves to the original constant (onnx.TensorProto.FLOAT8E4M3FN), and no other FLOAT8 variants or direct dtype checks were replaced—existing models remain supported.

modelopt/torch/_deploy/utils/torch_onnx.py (1)

488-489: Verify INT4 quantization detection and add tests
Ensure is_int4_quantized(model) is implemented and correctly flags INT4‐quantized models, and add unit tests covering the INT4→FP16 conversion path in the ONNX utility tests.

tests/unit/onnx/test_qdq_utils.py (1)

55-61: Good change: make Reshape shape a Constant input

Using a Constant for the reshape shape is cleaner and avoids managing a dedicated initializer. Wiring it into Reshape looks correct.

Also applies to: 65-65

codecov bot commented Sep 5, 2025

Codecov Report

❌ Patch coverage is 65.90909% with 15 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.87%. Comparing base (0d279f1) to head (02d70b9).
⚠️ Report is 10 commits behind head on main.

Files with missing lines (patch %, lines missing):
  • modelopt/onnx/quantization/qdq_utils.py: 69.04%, 13 missing ⚠️
  • modelopt/torch/_deploy/utils/torch_onnx.py: 0.00%, 2 missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #289      +/-   ##
==========================================
- Coverage   73.94%   73.87%   -0.07%     
==========================================
  Files         172      172              
  Lines       17405    17438      +33     
==========================================
+ Hits        12870    12883      +13     
- Misses       4535     4555      +20     

☔ View full report in Codecov by Sentry.

cjluo-nv (Collaborator) left a comment

Please fix the rabbit's comment

@ajrasane ajrasane force-pushed the ajrasane/onnx_version_upgrade branch from d792d4f to c06fcac on September 5, 2025 18:05
coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
examples/onnx_ptq/torch_quant_to_onnx.py (1)

62-67: Don’t move dataset tensors to CUDA before DataLoader; move per-batch inside the loop

Creating a DataLoader over GPU tensors with num_workers>0 can hang/fail (IPC/pickling). Keep data on CPU, enable pin_memory, and move to device inside the forward loop.

-    calib_tensor = [t.to(device) for t in calib_tensor]
-    return torch.utils.data.DataLoader(
-        calib_tensor, batch_size=batch_size, shuffle=True, num_workers=4
-    )
+    return torch.utils.data.DataLoader(
+        calib_tensor,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=4,
+        pin_memory=(device.type == "cuda"),
+    )

And in quantize_model’s forward pass:

-        def forward_loop(model):
-            for batch in data_loader:
-                model(batch)
+        def forward_loop(model):
+            device = next(model.parameters()).device
+            with torch.inference_mode():
+                for batch in data_loader:
+                    if isinstance(batch, (list, tuple)):
+                        batch = torch.stack(batch)  # safety if collation returns list
+                    batch = batch.to(device, non_blocking=True)
+                    model(batch)
modelopt/onnx/quantization/qdq_utils.py (4)

629-633: Create FP8 initializer with Float8 dtype and raw bytes

Current code produces a UINT8 tensor. Use make_tensor with Float8 and raw bytes to preserve dtype.

-def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
-    """Create a FLOAT8E4M3FN tensor directly from numpy array."""
-    fp8_data = _cast_fp8(scaled)
-    return onnx.numpy_helper.from_array(fp8_data, weight_name)
+def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
+    """Create a FLOAT8E4M3FN tensor with correct dtype and raw bytes."""
+    fp8_bytes = _cast_fp8(scaled).tobytes()
+    return onnx.helper.make_tensor(
+        name=weight_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=list(scaled.shape),
+        vals=fp8_bytes,
+        raw=True,
+    )

595-601: Bug: dtype check compares NumPy dtype to ONNX enum

zp_array.dtype == onnx_dtype_map["Float8"] will never be true. You lost the ONNX dtype when converting to NumPy. Propagate the TensorProto dtype and branch on that.

-    if zp_array.dtype == onnx_dtype_map["Float8"]:
+    if zp_dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
         np.clip(scaled + zp_array, -128, 127, out=scaled)

Outside this hunk, update helpers to return/accept dtypes:

# In _get_scale_and_zp(...):
# return both arrays and their ONNX data_type enums
def _get_scale_and_zp(...)-> tuple[np.ndarray, np.ndarray, int, int]:
    ...
    scale_dtype = scale.data_type
    zp_dtype = zp.data_type
    scale_array = onnx.numpy_helper.to_array(scale)
    zp_array = onnx.numpy_helper.to_array(zp)
    return scale_array, zp_array, scale_dtype, zp_dtype

# Update _convert_weight signature to accept zp_dtype (int ONNX enum)
def _convert_weight(..., zp_dtype: int, ...) -> np.ndarray:
    ...

# And pass it from qdq_to_dq:
scale_array, zp_array, _, zp_dtype = _get_scale_and_zp(...)
scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node, zp_dtype)

693-699: qDQ path must use ONNX dtype of zero-point, not NumPy dtype

Follow-on to the previous fix: branch on zp_dtype.

-            if zp_array.dtype == onnx_dtype_map["Float8"]:
+            if zp_dtype == onnx_dtype_map["Float8"]:
                 new_weight = _create_fp8_tensor(scaled, weight_name)
                 logger.debug(f"Converted {weight_name} to FP8")
             else:
                 new_weight = onnx.numpy_helper.from_array(scaled.astype("int8"), weight_name)
                 logger.debug(f"Converted {weight_name} to INT8")

604-611: Use explicit FLOAT8E4M3FN when creating the FP8 tensor
Replace the call to onnx.numpy_helper.from_array(fp8_data, weight_name) in _create_fp8_tensor with an explicit make_tensor invocation that sets data_type=onnx.TensorProto.FLOAT8E4M3FN, passes the raw bytes, and specifies the correct shape. For example:

import onnx

def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
    fp8_data = _cast_fp8(scaled)
    return onnx.helper.make_tensor(
        name=weight_name,
        data_type=onnx.TensorProto.FLOAT8E4M3FN,
        dims=fp8_data.shape,
        vals=fp8_data.tobytes(),
        raw=True,
    )

This ensures the initializer’s data_type is FLOAT8E4M3FN rather than UINT8.

♻️ Duplicate comments (1)
modelopt/onnx/quantization/qdq_utils.py (1)

614-627: FP4 packing along wrong axis; assertion is on axis 0 instead of last axis; docstring missing requirement

FP4 should be packed along the last dimension (matches NVFP4QTensor.quantize). Asserting and reshaping on dim 0 will break many weight shapes. Also document the even-length requirement.

-def _cast_fp4(array: np.ndarray) -> np.ndarray:
-    """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
-    array_f32_t = torch.from_numpy(array)
-    array_f32_t_shape = array_f32_t.shape
-    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
-    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
-    if torch.cuda.is_available():
-        array_f32_t = array_f32_t.cuda()
-    array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4_t = array_f4_t.flatten()
-    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
-    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
-    return array_f4
+def _cast_fp4(array: np.ndarray) -> np.ndarray:
+    """Cast a numpy array to FLOAT4E2M1 using PyTorch.
+
+    Note: The last dimension must be even; two FP4 values are packed per byte.
+    """
+    array_f32_t = torch.from_numpy(array)
+    if array_f32_t.shape[-1] % 2 != 0:
+        raise ValueError(
+            f"Last dimension must be divisible by 2 for FP4 packing; got {array_f32_t.shape[-1]}"
+        )
+    if torch.cuda.is_available():
+        array_f32_t = array_f32_t.cuda()
+    q4 = NVFP4QTensor._cast_fp4(array_f32_t)  # values in [0..15], same shape as input
+    packed = (q4[..., 1::2] << 4) | q4[..., 0::2]  # pack along last dim
+    return packed.cpu().numpy().astype(np.uint8)   # shape: (*, last_dim//2)
🧹 Nitpick comments (8)
setup.py (1)

50-51: ONNX/ORT version alignment check

Bumping to onnx~=1.19.0 and onnxconverter-common~=1.16.0 looks fine. Please verify ORT packages still resolve and support new dtypes (Float4/Float8) used elsewhere. Also consider aligning onnxruntime-directml to ~=1.22.0 for consistency unless there’s a known constraint.

examples/onnx_ptq/torch_quant_to_onnx.py (3)

86-92: Input shape: validate batch_size > 0

Return shape logic is good. Add a guard for batch_size >= 1 to avoid silent mis-shapes.

 def get_model_input_shape(model_name, batch_size):
     """Get the input shape from timm model configuration."""
     model = timm.create_model(model_name, pretrained=True, num_classes=1000)
     data_config = timm.data.resolve_model_data_config(model)
     input_size = data_config["input_size"]
-    return (batch_size, *tuple(input_size))  # Add batch dimension
+    if batch_size < 1:
+        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
+    return (batch_size, *tuple(input_size))  # Add batch dimension

122-127: CLI: constrain batch_size

Add argparse-level validation and clarify help.

 parser.add_argument(
     "--batch_size",
-    type=int,
-    default=1,
-    help="Batch size for calibration.",
+    type=int,
+    default=1,
+    choices=range(1, 1024),
+    metavar="{1..1023}",
+    help="Batch size for calibration (>=1).",
 )

131-150: Minor: set eval/inference during quantization

Set model.eval() before quantize and rely on inference_mode in forward_loop (as suggested above) to speed up and stabilize calibration behavior.

-    # Quantize model
-    quantized_model = quantize_model(model, config, data_loader)
+    model.eval()
+    quantized_model = quantize_model(model, config, data_loader)
modelopt/onnx/quantization/qdq_utils.py (4)

951-956: Assumes Cast always follows Reshape

Guard against missing/extra nodes to avoid crashes; gracefully skip if pattern doesn’t match.

-        # Remove unnecessary Cast node
-        cast_node = reshape_child_nodes[0]
-        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
-        nodes_to_remove.append(cast_node.name)
-        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+        # Remove unnecessary Cast node (if present)
+        cast_child_nodes = reshape_child_nodes
+        if reshape_child_nodes and reshape_child_nodes[0].op_type == "Cast":
+            cast_node = reshape_child_nodes[0]
+            nodes_to_remove.append(cast_node.name)
+            cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]

958-977: Be robust if there is no Transpose after Cast

Avoid indexing without checks.

-        if cast_child_nodes[0].op_type == "Transpose":
-            transpose_node = cast_child_nodes[0]
+        if cast_child_nodes and cast_child_nodes[0].op_type == "Transpose":
+            transpose_node = cast_child_nodes[0]
             nodes_to_remove.append(transpose_node.name)
             ...
-        else:
-            matmul_node = cast_child_nodes[0]
+        else:
+            assert cast_child_nodes, f"No consumer found after Cast/Reshape for {node.name}"
+            matmul_node = cast_child_nodes[0]

1007-1021: Pre-quant scale cleanup assumes exactly one Cast child

Loosen the assumptions: skip if multiple/no children or child not a Cast; only rewire when safe.

-    for node in graph.node:
-        if is_pre_quant_scale_node(node):
-            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
-            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
-            cast_node = pqs_child_nodes[0]
-            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
-            node.output.clear()
-            node.output.extend(cast_node.output)
-            nodes_to_remove.append(cast_node.name)
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            if len(pqs_child_nodes) != 1 or pqs_child_nodes[0].op_type != "Cast":
+                continue
+            cast_node = pqs_child_nodes[0]
+            node.output[:] = cast_node.output
+            nodes_to_remove.append(cast_node.name)

1335-1343: Avoid magic number for BF16 detection

Use the enum for readability and safety.

-        for initializer in graph.initializer:
-            if initializer.data_type == 16:
+        for initializer in graph.initializer:
+            if initializer.data_type == onnx.TensorProto.BFLOAT16:
                 precision_dtype = "BFloat16"
                 break
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between d792d4f and c06fcac.

📒 Files selected for processing (5)
  • examples/onnx_ptq/torch_quant_to_onnx.py (2 hunks)
  • modelopt/onnx/quantization/qdq_utils.py (11 hunks)
  • modelopt/torch/_deploy/utils/torch_onnx.py (1 hunks)
  • setup.py (1 hunks)
  • tests/unit/onnx/test_qdq_utils.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • tests/unit/onnx/test_qdq_utils.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1)
  • NVFP4QTensor (31-295)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/onnx/quantization/qdq_utils.py (3)

52-54: Add Float4 mapping

The Float4 → FLOAT4E2M1 mapping is expected for ONNX 1.19+. LGTM.


1131-1137: MXFP8 initializer creation looks correct

Using make_tensor with Float8 dtype and raw bytes is the right approach. LGTM.


1230-1236: FP8 scale initializer: consistent approach

Using Float8 dtype + raw bytes is consistent with MXFP8 path. LGTM.

Comment on lines +1219 to +1229
    w_f4_proto = onnx.helper.make_tensor(
        name=w_f4_name,
        data_type=onnx_dtype_map["Float4"],
        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
        vals=w_f4.tobytes(),
        raw=True,
    )

🛠️ Refactor suggestion

FP4 initializer dims should reflect packing along the last axis

After fixing _cast_fp4 to pack along the last dim, adjust dims accordingly.

-    w_f4_proto = onnx.helper.make_tensor(
-        name=w_f4_name,
-        data_type=onnx_dtype_map["Float4"],
-        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
-        vals=w_f4.tobytes(),
-        raw=True,
-    )
+    w_f4_proto = onnx.helper.make_tensor(
+        name=w_f4_name,
+        data_type=onnx_dtype_map["Float4"],
+        dims=[*w_f4.shape[:-1], w_f4.shape[-1] * 2],
+        vals=w_f4.tobytes(),
+        raw=True,
+    )
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/qdq_utils.py around lines 1219 to 1225, the FP4
initializer currently doubles the first dimension but FP4 packing was changed to
pack along the last axis; update the dims to reflect packing along the last axis
by replacing dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]] with
dims=[*w_f4.shape[:-1], w_f4.shape[-1] * 2] (or equivalent list/tuple
construction) so the last dimension is doubled instead of the first.

@ajrasane ajrasane force-pushed the ajrasane/onnx_version_upgrade branch from 82d26ec to 3420d48 on September 5, 2025 22:43
@ajrasane ajrasane enabled auto-merge (squash) September 5, 2025 22:44
coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/onnx/quantization/qdq_utils.py (2)

595-601: Bug: FP8 detection via NumPy dtype is incorrect

Comparing zp_array.dtype to an ONNX enum won’t work; NumPy returns uint8 (or a structured uint8), not an ONNX dtype. This branch will never select the FP8 path.

Fix by checking the zero-point TensorProto’s data_type (or pass a boolean). Example patch:

-def _get_scale_and_zp(
+def _get_scale_and_zp(
     node: onnx.NodeProto,
     initializers: dict[str, onnx.TensorProto],
     tensor_producers: dict[str, onnx.NodeProto],
-) -> tuple[np.ndarray, np.ndarray]:
+) -> tuple[np.ndarray, np.ndarray, int]:
@@
-    return scale_array, zp_array
+    return scale_array, zp_array, zp.data_type
@@
-def _convert_weight(
+def _convert_weight(
     weight_array: np.ndarray,
     scale_array: np.ndarray,
-    zp_array: np.ndarray,
+    zp_array: np.ndarray,
+    zp_dtype: int,
     quantized_node: onnx.NodeProto,
 ) -> np.ndarray:
@@
-    if zp_array.dtype == onnx_dtype_map["Float8"]:
+    if zp_dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
         np.clip(scaled + zp_array, -128, 127, out=scaled)
@@
-            scale_array, zp_array = _get_scale_and_zp(node, initializers, tensor_producers)
+            scale_array, zp_array, zp_dtype = _get_scale_and_zp(node, initializers, tensor_producers)
@@
-            scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
+            scaled = _convert_weight(weight_array, scale_array, zp_array, zp_dtype, quantized_node)

633-637: Create FP8 initializers with proper ONNX dtype, not numpy_helper.from_array

numpy_helper.from_array will tag data as UINT8, not Float8. Use make_tensor with data_type Float8 and raw bytes (as done below in quantize_weights_to_mxfp8).

-def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
-    """Create a FLOAT8E4M3FN tensor directly from numpy array."""
-    fp8_data = _cast_fp8(scaled)
-    return onnx.numpy_helper.from_array(fp8_data, weight_name)
+def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
+    """Create a FLOAT8E4M3FN initializer with correct dtype."""
+    fp8_data = _cast_fp8(scaled)
+    return onnx.helper.make_tensor(
+        name=weight_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=[*scaled.shape],
+        vals=fp8_data.tobytes(),
+        raw=True,
+    )
♻️ Duplicate comments (1)
modelopt/onnx/quantization/qdq_utils.py (1)

932-935: Reshape shape Constant detection is brittle (re-raising prior feedback)

String-matching "Constant" in input names is fragile. Use Reshape’s second input and check its producer.

-        # Remove constant node from reshape node
-        shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
-        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+        # Remove Constant that feeds Reshape's shape, if present
+        if len(reshape_node.input) >= 2:
+            shape_name = reshape_node.input[1]
+            shape_producer = tensor_producer_map.get(shape_name)
+            if shape_producer is not None and shape_producer.op_type == "Constant":
+                nodes_to_remove.append(shape_producer.name)
🧹 Nitpick comments (2)
tests/unit/onnx/test_qdq_utils.py (1)

315-342: FP4 tests: packed shape and dtype OK; add an odd-first-dim guard test

  • Expecting a packed shape along the first axis (2xN → 1xN) and dtype uint8 matches _cast_fp4’s contract.
  • Consider adding a negative test where the first dim is odd to assert the error path.

I can add a parametric test that asserts the raised error for odd first-dimension inputs. Want me to push it?

Also applies to: 348-348
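
A hedged sketch of that parametric test, assuming _cast_fp4 keeps raising AssertionError for an odd first dimension as in the current implementation:

import numpy as np
import pytest

from modelopt.onnx.quantization.qdq_utils import _cast_fp4


@pytest.mark.parametrize("shape", [(3, 2), (5, 4), (1, 8)])
def test_cast_fp4_odd_first_dim_raises(shape):
    # Packing stores two 4-bit codes per byte, so the first dimension must be even.
    with pytest.raises(AssertionError):
        _cast_fp4(np.zeros(shape, dtype=np.float32))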

modelopt/onnx/quantization/qdq_utils.py (1)

1223-1229: Confirm FP4 initializer dims align with packing axis

Dims double the first axis, consistent with _cast_fp4’s first-axis packing. Verify all consumers assume this convention; earlier feedback suggested last-axis packing—ensure consistency across exporters/importers.

If you plan to switch packing to the last axis later, centralize the “packed axis” in one utility to avoid mismatches.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 82d26ec and 3420d48.

📒 Files selected for processing (5)
  • examples/onnx_ptq/torch_quant_to_onnx.py (2 hunks)
  • modelopt/onnx/quantization/qdq_utils.py (11 hunks)
  • modelopt/torch/_deploy/utils/torch_onnx.py (1 hunks)
  • setup.py (1 hunks)
  • tests/unit/onnx/test_qdq_utils.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • examples/onnx_ptq/torch_quant_to_onnx.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • setup.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1)
  • NVFP4QTensor (31-295)
tests/unit/onnx/test_qdq_utils.py (1)
modelopt/onnx/quantization/qdq_utils.py (1)
  • _cast_fp4 (614-630)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (7)
tests/unit/onnx/test_qdq_utils.py (4)

55-66: Reshape shape via Constant looks good

Using a Constant for the reshape shape and wiring it as the second input of Reshape aligns with ONNX patterns and avoids depending on initializers.


70-81: Explicit Cast between Reshape and Transpose is fine

Adding the Cast and feeding Transpose from it mirrors the production graphs these tests target. No issues spotted.


96-96: Node/initializer lists updated correctly

Including the new Constant/Cast nodes in nodes and removing the reshape shape initializer is consistent with the refactor intent.

Also applies to: 104-104


252-257: Cast skip rule matches implementation

Tests now only preserve FP32 casts whose names contain "norm/Cast"; all other FP32 casts are converted to FP16. This matches the quantization pass behavior.

modelopt/onnx/quantization/qdq_utils.py (3)

52-54: ONNX Float4 mapping added

Mapping "Float4" to FLOAT4E2M1 is correct for ONNX ≥ 1.19.

If CI uses multiple ONNX versions, please ensure FLOAT4E2M1 exists at runtime to avoid import-time failures.
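
One way to make that runtime guard explicit, as a sketch (the variable name and error text are illustrative, not the module's actual code):

import onnx

_FLOAT4 = getattr(onnx.TensorProto, "FLOAT4E2M1", None)
if _FLOAT4 is None:
    raise ImportError(
        "FP4 support requires an ONNX build that defines TensorProto.FLOAT4E2M1 "
        f"(installed onnx=={onnx.__version__})"
    )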


1135-1141: Good: FP8 initializer created with correct dtype and raw bytes

This mirrors the recommended pattern for Float8 in ONNX.


610-611: Return structured uint8 dtype in _cast_fp8
Update _cast_fp8 to return a structured‐dtype view so tests pass and serialization still works:

-    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
-    return array_f8
+    array_f8 = array_f8_t.cpu().numpy()
+    return array_f8.view(np.dtype((np.uint8, [("e4m3fn", "u1")])))

Comment on lines +1011 to +1025
    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
        return node.op_type == "Mul" and has_pqs_input

    # Remove unnecessay Cast after Pre-quant scale
    for node in graph.node:
        if is_pre_quant_scale_node(node):
            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
            cast_node = pqs_child_nodes[0]
            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
            node.output.clear()
            node.output.extend(cast_node.output)
            nodes_to_remove.append(cast_node.name)


💡 Verification agent

🧩 Analysis chain

Pre-quant Mul detection is wrong; the naming doesn’t match inserted nodes

The scale input is named with “_awq_scale” and/or produced by a Constant/initializer, not “_pre_quant_scale”. This block won’t trigger.

Make detection structural: Mul with one initializer or Constant input.

-    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
-        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
-        return node.op_type == "Mul" and has_pqs_input
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        if node.op_type != "Mul" or len(node.input) != 2:
+            return False
+        for inp in node.input:
+            if inp in initializer_map:
+                return True
+            prod = tensor_producer_map.get(inp)
+            if prod is not None and prod.op_type == "Constant":
+                return True
+        return False

Please run the unit test that exercises removal of Cast after pre-quant Mul to confirm this now triggers.


🏁 Script executed:

#!/bin/bash

echo "Occurrences of _awq_scale:"
rg -n "_awq_scale" -g "*.py" || true

echo
echo "Occurrences of pre_quant_scale:"
rg -n "pre_quant_scale" -g "*.py" || true

echo
echo "Definition of is_pre_quant_scale_node and context:"
rg -n -C3 "def is_pre_quant_scale_node" -g "modelopt/onnx/quantization/qdq_utils.py" || true

echo
echo "initializer_map occurrences in qdq_utils.py:"
rg -n "initializer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true

echo
echo "tensor_producer_map occurrences in qdq_utils.py:"
rg -n "tensor_producer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true

Length of output: 28239


Use structural detection for pre-quant Mul in qdq_utils.py
The current is_pre_quant_scale_node looks for “_pre_quant_scale” in the input name, but the inserted scale tensors are named with “_awq_scale”, so this code never matches and the Cast isn’t removed. Replace it with a structural check—i.e. a Mul with exactly two inputs, one coming from an initializer or Constant node:

-    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
-        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
-        return node.op_type == "Mul" and has_pqs_input
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        if node.op_type != "Mul" or len(node.input) != 2:
+            return False
+        for inp in node.input:
+            if inp in initializer_map:
+                return True
+            prod = tensor_producer_map.get(inp)
+            if prod is not None and prod.op_type == "Constant":
+                return True
+        return False
📝 Committable suggestion (replacement for lines 1011-1025; review carefully before committing):

    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
        # A pre-quantization scale is always a Mul with one constant input
        if node.op_type != "Mul" or len(node.input) != 2:
            return False
        for inp in node.input:
            # Check if one of the inputs is a graph initializer
            if inp in initializer_map:
                return True
            # Or produced by a Constant node
            prod = tensor_producer_map.get(inp)
            if prod is not None and prod.op_type == "Constant":
                return True
        return False

    # Remove unnecessary Cast after Pre-quant scale
    for node in graph.node:
        if is_pre_quant_scale_node(node):
            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
            cast_node = pqs_child_nodes[0]
            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
            node.output.clear()
            node.output.extend(cast_node.output)
            nodes_to_remove.append(cast_node.name)

@ajrasane ajrasane disabled auto-merge September 6, 2025 00:32
@gcunhase gcunhase self-requested a review September 9, 2025 00:07
gcunhase (Contributor) left a comment

LGTM

@ajrasane ajrasane requested a review from a team as a code owner September 9, 2025 08:41
coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
tests/_test_utils/import_helper.py (1)

80-92: Allow safe module-level skipping.

If this helper is ever used at import time, current skips lack allow_module_level=True. Harmless inside tests, but brittle at module level.

-    except importlib.metadata.PackageNotFoundError:
-        pytest.skip(f"{package_name} is not installed")
+    except importlib.metadata.PackageNotFoundError:
+        pytest.skip(f"{package_name} is not installed", allow_module_level=True)

-    if version.parse(installed_version) < version.parse(required_version):
-        pytest.skip(
-            f"{package_name} version {installed_version} is less than required {required_version}"
-        )
+    if version.parse(installed_version) < version.parse(required_version):
+        pytest.skip(
+            f"{package_name} version {installed_version} is less than required {required_version}",
+            allow_module_level=True,
+        )
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (2)

43-44: Gate ONNX version at import time to avoid ImportError before skip.

Current skip runs inside the test; onnx-dependent imports above may fail first in envs with <1.19 or missing onnx. Move the gate to module level and drop per-test calls.

-    skip_if_onnx_version_below_1_19()

Add near the imports (example):

# Top-level gate (place after standard imports)
from _test_utils.import_helper import skip_if_onnx_version_below_1_19
skip_if_onnx_version_below_1_19()

Note: This pairs with adding allow_module_level=True in the helper as suggested.


119-120: Same as above: move version gate to module level.

Prevents import-time failures and keeps skip reason centralized.

-    skip_if_onnx_version_below_1_19()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3420d48 and ed78802.

📒 Files selected for processing (3)
  • examples/windows/onnx_ptq/genai_llm/requirements.txt (1 hunks)
  • tests/_test_utils/import_helper.py (2 hunks)
  • tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
tests/_test_utils/import_helper.py (2)
  • skip_if_no_libcudnn (37-43)
  • skip_if_onnx_version_below_1_19 (80-92)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
tests/_test_utils/import_helper.py (1)

16-20: Ensure test dependency availability.

packaging is used at test time. Confirm it’s declared in test/extras to prevent CI env drift.

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)

23-23: Import looks good.

Comment on examples/windows/onnx_ptq/genai_llm/requirements.txt:

@@ -1,3 +1,4 @@
 datasets>=2.14.5
+onnx==1.18.0

⚠️ Potential issue

Fix version mismatch with PR objective (onnx 1.19).

This example pins onnx==1.18.0 while the PR upgrades repo/tooling to 1.19.0 and gates tests on >=1.19. Align to avoid feature/API skew (e.g., FP4/INT4 utilities).

-onnx==1.18.0
+onnx==1.19.0
🤖 Prompt for AI Agents
In examples/windows/onnx_ptq/genai_llm/requirements.txt around line 2, the file
pins onnx==1.18.0 which mismatches the repo/test expectation of onnx>=1.19.0;
update the requirement to onnx==1.19.0 (or onnx>=1.19.0 if a range is preferred)
so the example aligns with the PR tooling/tests and avoids API/feature skew
(e.g., FP4/INT4 utilities).

@ajrasane ajrasane force-pushed the ajrasane/onnx_version_upgrade branch from 03d342d to c1ea892 on September 9, 2025 23:18
@ajrasane ajrasane force-pushed the ajrasane/onnx_version_upgrade branch from c1ea892 to 02d70b9 on September 9, 2025 23:19
@ajrasane ajrasane enabled auto-merge (squash) September 9, 2025 23:20
coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/_test_utils/import_helper.py (1)

16-20: Imports look fine; optional fallback for older Pythons.

If Python <3.8 is still in scope anywhere, add a fallback to importlib_metadata backport. Otherwise, keep as-is.

Apply if needed:

+try:
+    import importlib.metadata as importlib_metadata
+except ModuleNotFoundError:  # pragma: no cover
+    import importlib_metadata  # type: ignore
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)

42-44: Consider skipping at collection time with a decorator to reduce duplication.

Optional: switch to @pytest.mark.skipif(...) (using a small predicate in the helper) so tests are marked-skipped at collection, not runtime. Keeps both tests consistent and avoids executing setup code before the skip.

Example approach:

- def test_int4_awq(tmp_path):
-     skip_if_onnx_version_above_1_18()
+ @pytest.mark.skipif(_onnx_gt_1_18(), reason="Unsupported for ONNX > 1.18.0")
+ def test_int4_awq(tmp_path):
     ...
- def test_int4_awq_cuda(tmp_path):
-     skip_if_onnx_version_above_1_18()
+ @pytest.mark.skipif(_onnx_gt_1_18(), reason="Unsupported for ONNX > 1.18.0")
+ def test_int4_awq_cuda(tmp_path):

Helper (in import_helper.py):

def _onnx_gt_1_18() -> bool:
    try:
        v = importlib.metadata.version("onnx")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(v) > version.parse("1.18.0")

Also applies to: 118-121

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ed78802 and 02d70b9.

📒 Files selected for processing (3)
  • examples/windows/onnx_ptq/genai_llm/requirements.txt (1 hunks)
  • tests/_test_utils/import_helper.py (2 hunks)
  • tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/windows/onnx_ptq/genai_llm/requirements.txt
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
tests/_test_utils/import_helper.py (2)
  • skip_if_no_libcudnn (37-43)
  • skip_if_onnx_version_above_1_18 (80-92)
🔇 Additional comments (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)

23-23: LGTM: Brings in the version gate helper where needed.

Comment on lines +80 to +92
def skip_if_onnx_version_above_1_18():
    package_name = "onnx"
    required_version = "1.18.0"

    try:
        installed_version = importlib.metadata.version(package_name)
    except importlib.metadata.PackageNotFoundError:
        pytest.skip(f"{package_name} is not installed")

    if version.parse(installed_version) > version.parse(required_version):
        pytest.skip(
            f"{package_name} version {installed_version} is less than required {required_version}"
        )

⚠️ Potential issue

Skip reason text is wrong; variable name misleads; add allow_module_level for consistency.

Condition skips when ONNX > 1.18, but the message says “less than required”. Rename to reflect max supported, fix message, and pass allow_module_level=True.

 def skip_if_onnx_version_above_1_18():
     package_name = "onnx"
-    required_version = "1.18.0"
+    max_supported_version = "1.18.0"

     try:
         installed_version = importlib.metadata.version(package_name)
     except importlib.metadata.PackageNotFoundError:
-        pytest.skip(f"{package_name} is not installed")
+        pytest.skip(f"{package_name} is not installed", allow_module_level=True)

-    if version.parse(installed_version) > version.parse(required_version):
+    if version.parse(installed_version) > version.parse(max_supported_version):
         pytest.skip(
-            f"{package_name} version {installed_version} is less than required {required_version}"
+            f"{package_name} version {installed_version} > supported {max_supported_version}; expected <= {max_supported_version}",
+            allow_module_level=True,
         )
🤖 Prompt for AI Agents
In tests/_test_utils/import_helper.py around lines 80 to 92, the helper misnames
the version variable and logs an incorrect skip message and omits
allow_module_level; rename required_version to max_supported_version (or
similar), update the skip message to say the installed ONNX version is greater
than the max supported (include installed_version and max_supported_version),
and call pytest.skip(..., allow_module_level=True) when skipping due to version
being above the supported maximum.

@ajrasane ajrasane merged commit 0adb1b6 into main Sep 10, 2025
22 checks passed
@ajrasane ajrasane deleted the ajrasane/onnx_version_upgrade branch September 10, 2025 00:40
jingyu-ml pushed a commit that referenced this pull request Sep 10, 2025
Signed-off-by: ajrasane <[email protected]>
Signed-off-by: Jingyu Xin <[email protected]>
benchislett pushed a commit that referenced this pull request Sep 15, 2025