43 changes: 12 additions & 31 deletions modelopt/onnx/quantization/qdq_utils.py
@@ -1111,17 +1111,20 @@ def quantize_weights_to_int4(
         scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
         scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
+        assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}"

-        # Remove unnecessary Cast node
-        cast_node = reshape_child_nodes[0]
-        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
-        nodes_to_remove.append(cast_node.name)
-        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+        # Check if there's an optional Cast node between Reshape and Transpose/MatMul/Gemm
+        next_node = reshape_child_nodes[0]
+        if next_node.op_type == "Cast":
+            # Remove unnecessary Cast node
+            cast_node = next_node
+            nodes_to_remove.append(cast_node.name)
+            cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+            next_node = cast_child_nodes[0]

         # Transpose weights and scales if present
-        if cast_child_nodes[0].op_type == "Transpose":
-            transpose_node = cast_child_nodes[0]
+        if next_node.op_type == "Transpose":
+            transpose_node = next_node
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -1138,7 +1141,7 @@
                 )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = cast_child_nodes[0]
+            matmul_node = next_node
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
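Taken on its own, the traversal these two hunks implement is: follow the Reshape output through an optional Cast and an optional Transpose until the MatMul/Gemm is reached. A minimal standalone sketch of that pattern (`find_matmul_after_reshape` and `_consumers` are hypothetical helpers for illustration, not part of this PR):

```python
import onnx


def _consumers(graph: onnx.GraphProto, node: onnx.NodeProto) -> list[onnx.NodeProto]:
    """All nodes that read this node's first output."""
    return [n for n in graph.node if node.output[0] in n.input]


def find_matmul_after_reshape(graph: onnx.GraphProto, reshape_node: onnx.NodeProto) -> onnx.NodeProto:
    """Walk Reshape -> [Cast] -> [Transpose] -> MatMul/Gemm, skipping the optional hops."""
    children = _consumers(graph, reshape_node)
    assert len(children) == 1, f"Expected exactly one child node for {reshape_node.name}"
    next_node = children[0]
    if next_node.op_type == "Cast":  # optional Cast between Reshape and the GEMM
        next_node = _consumers(graph, next_node)[0]
    if next_node.op_type == "Transpose":  # optional Transpose of the weight
        next_node = _consumers(graph, next_node)[0]
    assert next_node.op_type in ["MatMul", "Gemm"], f"Expected MatMul or Gemm, got {next_node.op_type}"
    return next_node
```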
@@ -1189,21 +1192,6 @@ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
     del graph.node[:]
     graph.node.extend(new_nodes)

-    def is_fp32_cast(node: onnx.NodeProto) -> bool:
-        return any(
-            attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
-        )
-
-    # Change all Cast nodes that cast to float32 (TensorProto.FLOAT) to cast to float16 (TensorProto.FLOAT16)
-    for node in graph.node:
-        if node.op_type == "Cast":
-            # Skip Cast nodes that are part of normalization layers and outputs
-            if "norm/Cast" in node.name and is_fp32_cast(node):
-                continue
-            for attr in node.attribute:
-                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
-                    attr.i = onnx.TensorProto.FLOAT16

     # Cast bias to float16
     for node in graph.node:
         if node.op_type == "Add" and "proj/Add" in node.name:
@@ -1310,13 +1298,6 @@ def quantize_weights_to_mxfp8(
                 if attr.name == "output_dtype":
                     attr.i = onnx_dtype_map["Half"]

-    # set Cast to FP16
-    for node in graph.node:
-        if node.op_type == "Cast":
-            for attr in node.attribute:
-                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
-                    attr.i = onnx_dtype_map["Half"]
-
     # Currently only tanh approximation is supported for Gelu
     for node in gelu_nodes:
         for attr in node.attribute:
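The Gelu loop, truncated here, follows the same attribute-editing pattern as the dtype fixes above. A minimal sketch of what such a pass can look like, assuming the standard ONNX Gelu (opset 20) with its string-valued `approximate` attribute (`force_tanh_gelu` is a hypothetical name):

```python
import onnx


def force_tanh_gelu(graph: onnx.GraphProto) -> None:
    """Set every Gelu node's 'approximate' attribute to 'tanh'."""
    for node in graph.node:
        if node.op_type != "Gelu":
            continue
        for attr in node.attribute:
            if attr.name == "approximate":
                attr.s = b"tanh"  # protobuf stores string attributes as bytes
```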