Skip to content

Commit 1ab62d5

Browse files
committed
ONNX 1.19 fix: set the exported model's IR version to 10 for compatibility with onnxruntime, and add the `axis` and `block_size` attributes to the DQ nodes
Signed-off-by: Hrishith Thadicherla <[email protected]>
1 parent 10e6894 commit 1ab62d5

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

modelopt/onnx/quantization/gs_patching.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,6 @@ def _make_constant(
4848
setattr(t, "explicit_dtype", dtype)
4949
return t
5050

51-
5251
def _make_variable(
5352
name: str, dtype: onnx.TensorProto.DataType, shape: Sequence[int | str]
5453
) -> gs.Constant:

modelopt/onnx/quantization/int4.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@
9898
# supported and working
9999
CLIP_MIN = 1e-5
100100

101-
102101
def safe_cupy_array(tensor):
103102
"""Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.
104103
@@ -294,19 +293,23 @@ def quantize_rtn(
294293
scales[name] = np.asnumpy(scales[name])
295294
gemm_weights_quantized[name] = numpy.asarray(qw)
296295
scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
296+
dq_node_attributes = {"axis": 0, "block_size": block_size}
297297
qdq.insert_dq_nodes(
298298
graph,
299299
scales,
300300
quantized_weights=gemm_weights_quantized,
301+
attributes=dq_node_attributes,
301302
precision_info=precision_info,
302303
)
303304

304305
if gather_w_map is not None:
305306
assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
307+
gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size}
306308
qdq.insert_dq_nodes(
307309
graph,
308310
gather_s_map,
309311
quantized_weights=gather_w_map,
312+
attributes=gather_dq_node_attributes,
310313
precision_info=precision_info,
311314
)
312315
else:
@@ -322,8 +325,10 @@ def quantize_rtn(
322325
)
323326

324327
logger.info(f"RTN quantization completed in {time.time() - t_start:.2f} seconds")
325-
return gs.export_onnx(graph)
328+
model = gs.export_onnx(graph)
329+
model.ir_version = 10
326330

331+
return model
327332

328333
class AWQClipHelper:
329334
"""AWQ calibration helper class."""

0 commit comments

Comments
 (0)