-
Notifications
You must be signed in to change notification settings - Fork 162
Upgrade to ONNX 1.19.0 #289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
datasets>=2.14.5 | ||
onnx==1.18.0 | ||
torch==2.6.0 | ||
transformers==4.49.0 |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -23,7 +23,6 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import onnx_graphsurgeon as gs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from onnx import numpy_helper | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from onnx.reference.custom_element_types import float8e4m3fn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from modelopt.onnx import utils | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from modelopt.onnx.logging_config import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -50,6 +49,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
onnx_dtype_map = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"BFloat16": onnx.TensorProto.BFLOAT16, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Float": onnx.TensorProto.FLOAT, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Float4": onnx.TensorProto.FLOAT4E2M1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Float8": onnx.TensorProto.FLOAT8E4M3FN, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Half": onnx.TensorProto.FLOAT16, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"INT8": onnx.TensorProto.INT8, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -592,7 +592,7 @@ def _convert_weight( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
zp_array = zp_array.reshape(*reshape_dims) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Convert to INT8/FP8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if zp_array.dtype == float8e4m3fn: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if zp_array.dtype == onnx_dtype_map["Float8"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scaled = np.asarray(weight_array / scale_array) + zp_array | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scaled = np.asarray((weight_array / scale_array).round()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -607,17 +607,26 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if torch.cuda.is_available(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f32_t = array_f32_t.cuda() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f8 = array_f8_t.cpu().numpy().astype(np.uint8) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return array_f8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _cast_fp4(array: np.ndarray) -> np.ndarray: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Cast a numpy array to FLOAT4E2M1 using PyTorch.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Cast a numpy array to FLOAT4E2M1 using PyTorch. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Note: The first dimension of the array must be divisible by 2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
as two FP4 values are packed into a single byte. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f32_t = torch.from_numpy(array) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f32_t_shape = array_f32_t.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if torch.cuda.is_available(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f32_t = array_f32_t.cuda() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4_t = array_f4_t.flatten() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ajrasane marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return array_f4 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -685,7 +694,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Create and update new weight tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if zp_array.dtype == float8e4m3fn: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if zp_array.dtype == onnx_dtype_map["Float8"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
new_weight = _create_fp8_tensor(scaled, weight_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.debug(f"Converted {weight_name} to FP8") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -920,6 +929,10 @@ def quantize_weights_to_int4( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
reshape_node_output = reshape_node.output[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Remove constant node from reshape node | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
shape_constant_name = next(input for input in reshape_node.input if "Constant" in input) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nodes_to_remove.append(tensor_producer_map[shape_constant_name].name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Get the shape of the output of the reshape node | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
reshape_output_value_info = value_info_map.get(reshape_node_output) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if reshape_output_value_info is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -937,12 +950,17 @@ def quantize_weights_to_int4( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scale = scale.reshape(scale_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# reshape_node.input = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Remove unnecessary Cast node | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cast_node = reshape_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nodes_to_remove.append(cast_node.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Transpose weights and scales if present | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if reshape_child_nodes[0].op_type == "Transpose": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
transpose_node = reshape_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if cast_child_nodes[0].op_type == "Transpose": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
transpose_node = cast_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nodes_to_remove.append(transpose_node.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
perm = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -959,7 +977,7 @@ def quantize_weights_to_int4( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
matmul_node = transpose_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
matmul_node = reshape_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
matmul_node = cast_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert matmul_node.op_type in ["MatMul", "Gemm"], ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"Expected MatMul or Gemm node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -990,6 +1008,21 @@ def quantize_weights_to_int4( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
initializer_map[weight_name].CopyFrom(weights_int4_onnx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.debug(f"Converted {weight_name} to INT4 precision") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return node.op_type == "Mul" and has_pqs_input | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Remove unnecessay Cast after Pre-quant scale | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for node in graph.node: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if is_pre_quant_scale_node(node): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cast_node = pqs_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
node.output.clear() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
node.output.extend(cast_node.output) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nodes_to_remove.append(cast_node.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+1011
to
+1025
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainPre-quant Mul detection is wrong; the naming doesn’t match inserted nodes The scale input is named with “_awq_scale” and/or produced by a Constant/initializer, not “_pre_quant_scale”. This block won’t trigger. Make detection structural: Mul with one initializer or Constant input. - def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
- has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
- return node.op_type == "Mul" and has_pqs_input
+ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+ if node.op_type != "Mul" or len(node.input) != 2:
+ return False
+ for inp in node.input:
+ if inp in initializer_map:
+ return True
+ prod = tensor_producer_map.get(inp)
+ if prod is not None and prod.op_type == "Constant":
+ return True
+ return False Please run the unit test that exercises removal of Cast after pre-quant Mul to confirm this now triggers. 🏁 Script executed: #!/bin/bash
echo "Occurrences of _awq_scale:"
rg -n "_awq_scale" -g "*.py" || true
echo
echo "Occurrences of pre_quant_scale:"
rg -n "pre_quant_scale" -g "*.py" || true
echo
echo "Definition of is_pre_quant_scale_node and context:"
rg -n -C3 "def is_pre_quant_scale_node" -g "modelopt/onnx/quantization/qdq_utils.py" || true
echo
echo "initializer_map occurrences in qdq_utils.py:"
rg -n "initializer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true
echo
echo "tensor_producer_map occurrences in qdq_utils.py:"
rg -n "tensor_producer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true Length of output: 28239 Use structural detection for pre-quant Mul in qdq_utils.py - def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
- has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
- return node.op_type == "Mul" and has_pqs_input
+ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+ if node.op_type != "Mul" or len(node.input) != 2:
+ return False
+ for inp in node.input:
+ if inp in initializer_map:
+ return True
+ prod = tensor_producer_map.get(inp)
+ if prod is not None and prod.op_type == "Constant":
+ return True
+ return False 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Remove transpose and reshape nodes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
new_nodes = [node for node in graph.node if node.name not in nodes_to_remove] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
graph.node.clear() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1004,7 +1037,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for node in graph.node: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if node.op_type == "Cast": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Skip Cast nodes that are part of normalization layers and outputs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if "norm/Cast" in node.name and is_fp32_cast(node): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for attr in node.attribute: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1099,7 +1132,13 @@ def quantize_weights_to_mxfp8( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Expand block array so that it can be broadcasted with weight | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
weights_e4m3 = onnx.helper.make_tensor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
name=weight_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
data_type=onnx_dtype_map["Float8"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dims=[*scaled_weight.shape], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
vals=_cast_fp8(scaled_weight).tobytes(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raw=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
initializer_map[weight_name].CopyFrom(weights_e4m3) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.debug(f"Converted {weight_name} to MXFP8") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1181,11 +1220,24 @@ def _add_input_value_info(graph, tensor_proto): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Create TensorProto for initializers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
w_f4_proto = onnx.helper.make_tensor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
name=w_f4_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
data_type=onnx_dtype_map["Float4"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
vals=w_f4.tobytes(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raw=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+1223
to
+1229
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion FP4 initializer dims should reflect packing along the last axis After fixing _cast_fp4 to pack along the last dim, adjust dims accordingly. - w_f4_proto = onnx.helper.make_tensor(
- name=w_f4_name,
- data_type=onnx_dtype_map["Float4"],
- dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
- vals=w_f4.tobytes(),
- raw=True,
- )
+ w_f4_proto = onnx.helper.make_tensor(
+ name=w_f4_name,
+ data_type=onnx_dtype_map["Float4"],
+ dims=[*w_f4.shape[:-1], w_f4.shape[-1] * 2],
+ vals=w_f4.tobytes(),
+ raw=True,
+ ) 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sw_f32_per_tensor_proto = onnx.numpy_helper.from_array( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sw_f32_per_tensor, sw_f32_per_tensor_name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sw_f8_per_block_proto = onnx.helper.make_tensor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
name=sw_f8_per_block_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
data_type=onnx_dtype_map["Float8"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dims=[*sw_f8_per_block.shape], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
vals=sw_f8_per_block.tobytes(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raw=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Add ValueInfo for the initializers if not present | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
_add_input_value_info(graph, w_f4_proto) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -13,9 +13,11 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import importlib.metadata | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import shutil | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import pytest | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from packaging import version | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def skip_if_no_tensorrt(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -73,3 +75,18 @@ def skip_if_no_megatron(apex_or_te_required: bool = False, mamba_required: bool | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if mamba_required and not has_mamba: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pytest.skip("Mamba required for Megatron test", allow_module_level=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def skip_if_onnx_version_above_1_18(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
package_name = "onnx" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
required_version = "1.18.0" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
installed_version = importlib.metadata.version(package_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
except importlib.metadata.PackageNotFoundError: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pytest.skip(f"{package_name} is not installed") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if version.parse(installed_version) > version.parse(required_version): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pytest.skip( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"{package_name} version {installed_version} is less than required {required_version}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+80
to
+92
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Skip reason text is wrong; variable name misleads; add allow_module_level for consistency. Condition skips when ONNX > 1.18, but the message says “less than required”. Rename to reflect max supported, fix message, and pass def skip_if_onnx_version_above_1_18():
package_name = "onnx"
- required_version = "1.18.0"
+ max_supported_version = "1.18.0"
try:
installed_version = importlib.metadata.version(package_name)
except importlib.metadata.PackageNotFoundError:
- pytest.skip(f"{package_name} is not installed")
+ pytest.skip(f"{package_name} is not installed", allow_module_level=True)
- if version.parse(installed_version) > version.parse(required_version):
+ if version.parse(installed_version) > version.parse(max_supported_version):
pytest.skip(
- f"{package_name} version {installed_version} is less than required {required_version}"
+ f"{package_name} version {installed_version} > supported {max_supported_version}; expected <= {max_supported_version}",
+ allow_module_level=True,
) 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix version mismatch with PR objective (onnx 1.19).
This example pins onnx==1.18.0 while the PR upgrades repo/tooling to 1.19.0 and gates tests on >=1.19. Align to avoid feature/API skew (e.g., FP4/INT4 utilities).
📝 Committable suggestion
🤖 Prompt for AI Agents