Skip to content

Commit 639a0f3

Browse files
committed
fix qdq utils issues and remove global cast replacements
Signed-off-by: Luxiao Zheng <[email protected]>
1 parent b660d39 commit 639a0f3

File tree

2 files changed: +346 additions, −55 deletions

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,17 +1107,20 @@ def quantize_weights_to_int4(
11071107
scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
11081108
scale = scale.reshape(scale_shape)
11091109
reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
1110-
assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
1110+
assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
11111111

1112-
# Remove unnecessary Cast node
1113-
cast_node = reshape_child_nodes[0]
1114-
assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
1115-
nodes_to_remove.append(cast_node.name)
1116-
cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
1112+
# Check if there's an optional Cast node between Reshape and Transpose/MatMul/Gemm
1113+
next_node = reshape_child_nodes[0]
1114+
if next_node.op_type == "Cast":
1115+
# Remove unnecessary Cast node
1116+
cast_node = next_node
1117+
nodes_to_remove.append(cast_node.name)
1118+
cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
1119+
next_node = cast_child_nodes[0]
11171120

11181121
# Transpose weights and scales if present
1119-
if cast_child_nodes[0].op_type == "Transpose":
1120-
transpose_node = cast_child_nodes[0]
1122+
if next_node.op_type == "Transpose":
1123+
transpose_node = next_node
11211124
nodes_to_remove.append(transpose_node.name)
11221125
assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
11231126
perm = None
@@ -1134,7 +1137,7 @@ def quantize_weights_to_int4(
11341137
)
11351138
matmul_node = transpose_child_nodes[0]
11361139
else:
1137-
matmul_node = cast_child_nodes[0]
1140+
matmul_node = next_node
11381141
assert matmul_node.op_type in ["MatMul", "Gemm"], (
11391142
f"Expected MatMul or Gemm node for {node.name}"
11401143
)
@@ -1185,21 +1188,6 @@ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
11851188
del graph.node[:]
11861189
graph.node.extend(new_nodes)
11871190

1188-
def is_fp32_cast(node: onnx.NodeProto) -> bool:
1189-
return any(
1190-
attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
1191-
)
1192-
1193-
# Change all Cast nodes that cast to float32 (TensorProto.FLOAT) to cast to float16 (TensorProto.FLOAT16)
1194-
for node in graph.node:
1195-
if node.op_type == "Cast":
1196-
# Skip Cast nodes that are part of normalization layers and outputs
1197-
if "norm/Cast" in node.name and is_fp32_cast(node):
1198-
continue
1199-
for attr in node.attribute:
1200-
if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
1201-
attr.i = onnx.TensorProto.FLOAT16
1202-
12031191
# Cast bias to float16
12041192
for node in graph.node:
12051193
if node.op_type == "Add" and "proj/Add" in node.name:
@@ -1306,13 +1294,6 @@ def quantize_weights_to_mxfp8(
13061294
if attr.name == "output_dtype":
13071295
attr.i = onnx_dtype_map["Half"]
13081296

1309-
# set Cast to FP16
1310-
for node in graph.node:
1311-
if node.op_type == "Cast":
1312-
for attr in node.attribute:
1313-
if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
1314-
attr.i = onnx_dtype_map["Half"]
1315-
13161297
# Currently only tanh approximation is supported for Gelu
13171298
for node in gelu_nodes:
13181299
for attr in node.attribute:

Commit comments: 0