-
Notifications
You must be signed in to change notification settings - Fork 162
Upgrade to ONNX 1.19.0 #289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -23,7 +23,6 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import onnx_graphsurgeon as gs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from onnx import numpy_helper | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from onnx.reference.custom_element_types import float8e4m3fn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from modelopt.onnx import utils | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from modelopt.onnx.logging_config import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -50,6 +49,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
onnx_dtype_map = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"BFloat16": onnx.TensorProto.BFLOAT16, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Float": onnx.TensorProto.FLOAT, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Float4": onnx.TensorProto.FLOAT4E2M1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Float8": onnx.TensorProto.FLOAT8E4M3FN, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Half": onnx.TensorProto.FLOAT16, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"INT8": onnx.TensorProto.INT8, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -592,7 +592,7 @@ def _convert_weight( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
zp_array = zp_array.reshape(*reshape_dims) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Convert to INT8/FP8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if zp_array.dtype == float8e4m3fn: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if zp_array.dtype == onnx_dtype_map["Float8"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scaled = np.asarray(weight_array / scale_array) + zp_array | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scaled = np.asarray((weight_array / scale_array).round()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -607,17 +607,26 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if torch.cuda.is_available(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f32_t = array_f32_t.cuda() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f8 = array_f8_t.cpu().numpy().astype(np.uint8) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return array_f8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _cast_fp4(array: np.ndarray) -> np.ndarray: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Cast a numpy array to FLOAT4E2M1 using PyTorch.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Cast a numpy array to FLOAT4E2M1 using PyTorch. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Note: The first dimension of the array must be divisible by 2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
as two FP4 values are packed into a single byte. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f32_t = torch.from_numpy(array) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f32_t_shape = array_f32_t.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if torch.cuda.is_available(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f32_t = array_f32_t.cuda() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4_t = array_f4_t.flatten() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return array_f4 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -685,7 +694,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Create and update new weight tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if zp_array.dtype == float8e4m3fn: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if zp_array.dtype == onnx_dtype_map["Float8"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
new_weight = _create_fp8_tensor(scaled, weight_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.debug(f"Converted {weight_name} to FP8") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -920,6 +929,10 @@ def quantize_weights_to_int4( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
reshape_node_output = reshape_node.output[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Remove constant node from reshape node | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
shape_constant_name = next(input for input in reshape_node.input if "Constant" in input) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nodes_to_remove.append(tensor_producer_map[shape_constant_name].name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Get the shape of the output of the reshape node | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
reshape_output_value_info = value_info_map.get(reshape_node_output) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if reshape_output_value_info is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -937,12 +950,17 @@ def quantize_weights_to_int4( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scale = scale.reshape(scale_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# reshape_node.input = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Remove unnecessary Cast node | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cast_node = reshape_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nodes_to_remove.append(cast_node.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Transpose weights and scales if present | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if reshape_child_nodes[0].op_type == "Transpose": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
transpose_node = reshape_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if cast_child_nodes[0].op_type == "Transpose": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
transpose_node = cast_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nodes_to_remove.append(transpose_node.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
perm = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -959,7 +977,7 @@ def quantize_weights_to_int4( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
matmul_node = transpose_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
matmul_node = reshape_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
matmul_node = cast_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert matmul_node.op_type in ["MatMul", "Gemm"], ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"Expected MatMul or Gemm node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -990,6 +1008,21 @@ def quantize_weights_to_int4( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
initializer_map[weight_name].CopyFrom(weights_int4_onnx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.debug(f"Converted {weight_name} to INT4 precision") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return node.op_type == "Mul" and has_pqs_input | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Remove unnecessay Cast after Pre-quant scale | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for node in graph.node: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if is_pre_quant_scale_node(node): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cast_node = pqs_child_nodes[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
node.output.clear() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
node.output.extend(cast_node.output) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nodes_to_remove.append(cast_node.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+1011
to
+1025
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainPre-quant Mul detection is wrong; the naming doesn’t match inserted nodes The scale input is named with “_awq_scale” and/or produced by a Constant/initializer, not “_pre_quant_scale”. This block won’t trigger. Make detection structural: Mul with one initializer or Constant input. - def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
- has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
- return node.op_type == "Mul" and has_pqs_input
+ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+ if node.op_type != "Mul" or len(node.input) != 2:
+ return False
+ for inp in node.input:
+ if inp in initializer_map:
+ return True
+ prod = tensor_producer_map.get(inp)
+ if prod is not None and prod.op_type == "Constant":
+ return True
+ return False Please run the unit test that exercises removal of Cast after pre-quant Mul to confirm this now triggers. 🏁 Script executed: #!/bin/bash
echo "Occurrences of _awq_scale:"
rg -n "_awq_scale" -g "*.py" || true
echo
echo "Occurrences of pre_quant_scale:"
rg -n "pre_quant_scale" -g "*.py" || true
echo
echo "Definition of is_pre_quant_scale_node and context:"
rg -n -C3 "def is_pre_quant_scale_node" -g "modelopt/onnx/quantization/qdq_utils.py" || true
echo
echo "initializer_map occurrences in qdq_utils.py:"
rg -n "initializer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true
echo
echo "tensor_producer_map occurrences in qdq_utils.py:"
rg -n "tensor_producer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true Length of output: 28239 Use structural detection for pre-quant Mul in qdq_utils.py - def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
- has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
- return node.op_type == "Mul" and has_pqs_input
+ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+ if node.op_type != "Mul" or len(node.input) != 2:
+ return False
+ for inp in node.input:
+ if inp in initializer_map:
+ return True
+ prod = tensor_producer_map.get(inp)
+ if prod is not None and prod.op_type == "Constant":
+ return True
+ return False 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Remove transpose and reshape nodes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
new_nodes = [node for node in graph.node if node.name not in nodes_to_remove] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
graph.node.clear() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1004,7 +1037,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for node in graph.node: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if node.op_type == "Cast": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Skip Cast nodes that are part of normalization layers and outputs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if "norm/Cast" in node.name and is_fp32_cast(node): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for attr in node.attribute: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1099,7 +1132,13 @@ def quantize_weights_to_mxfp8( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Expand block array so that it can be broadcasted with weight | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
weights_e4m3 = onnx.helper.make_tensor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
name=weight_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
data_type=onnx_dtype_map["Float8"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dims=[*scaled_weight.shape], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
vals=_cast_fp8(scaled_weight).tobytes(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raw=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
initializer_map[weight_name].CopyFrom(weights_e4m3) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.debug(f"Converted {weight_name} to MXFP8") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1181,11 +1220,24 @@ def _add_input_value_info(graph, tensor_proto): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Create TensorProto for initializers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
w_f4_proto = onnx.helper.make_tensor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
name=w_f4_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
data_type=onnx_dtype_map["Float4"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
vals=w_f4.tobytes(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raw=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+1223
to
+1229
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion FP4 initializer dims should reflect packing along the last axis After fixing _cast_fp4 to pack along the last dim, adjust dims accordingly. - w_f4_proto = onnx.helper.make_tensor(
- name=w_f4_name,
- data_type=onnx_dtype_map["Float4"],
- dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
- vals=w_f4.tobytes(),
- raw=True,
- )
+ w_f4_proto = onnx.helper.make_tensor(
+ name=w_f4_name,
+ data_type=onnx_dtype_map["Float4"],
+ dims=[*w_f4.shape[:-1], w_f4.shape[-1] * 2],
+ vals=w_f4.tobytes(),
+ raw=True,
+ ) 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sw_f32_per_tensor_proto = onnx.numpy_helper.from_array( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sw_f32_per_tensor, sw_f32_per_tensor_name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sw_f8_per_block_proto = onnx.helper.make_tensor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
name=sw_f8_per_block_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
data_type=onnx_dtype_map["Float8"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dims=[*sw_f8_per_block.shape], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
vals=sw_f8_per_block.tobytes(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raw=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Add ValueInfo for the initializers if not present | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
_add_input_value_info(graph, w_f4_proto) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.