Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions modelopt/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import numpy as np
import onnx
import onnx_graphsurgeon as gs
from onnx import TensorProto, ValueInfoProto, numpy_helper
from onnx.helper import get_attribute_value
from onnx_graphsurgeon import Constant, Node, Variable

Expand Down Expand Up @@ -289,7 +288,7 @@ def _convert_types_to_np(types: dict[str, int] | list[int] | int) -> Any:

def get_tensor_by_name(
onnx_model: onnx.ModelProto, tensor_name: str
) -> ValueInfoProto | TensorProto | None:
) -> onnx.ValueInfoProto | onnx.TensorProto | None:
"""This function returns a tensor from its name.

This function searches for a tensor in the model's:
Expand Down Expand Up @@ -438,7 +437,7 @@ def randomize_weights_onnx_bytes(onnx_bytes: bytes, seed: int = 0) -> bytes:
numpy_array = np.random.normal(float(avg), float(var), size=init.dims).astype(
dtype
)
tensor = numpy_helper.from_array(numpy_array, init.name)
tensor = onnx.numpy_helper.from_array(numpy_array, init.name)
model.graph.initializer[idx].CopyFrom(tensor)

buffer = io.BytesIO()
Expand Down Expand Up @@ -751,3 +750,53 @@ def onnx_type_str_to_enum(dtype: str) -> int:
dtype = dtype.split("tensor(")[-1].split(")")[0]
dtype = "FLOAT" if dtype == "float32" else dtype.upper()
return getattr(onnx.TensorProto, dtype)


def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
"""Remove `training_mode` attribute and extra training outputs from nodes of a given op type.

This also removes the unused outputs from the training_mode nodes.

Args:
onnx_model: The onnx model.
node_op_type: The node type to remove training_mode attribute from.

Returns:
The onnx model with the training_mode attribute removed.
"""
removed_output_names = set()
all_inputs = {inp for n in onnx_model.graph.node for inp in n.input}
graph_outputs = {o.name for o in onnx_model.graph.output}
keep = all_inputs | graph_outputs

for node in onnx_model.graph.node:
if node.op_type != node_op_type:
continue

is_training_mode = False
# Drop the 'training_mode' attribute if present
for idx, attr in enumerate(list(node.attribute)):
if attr.name == "training_mode":
del node.attribute[idx]
if attr.i == 1:
is_training_mode = True
break

# If the node has extra outputs, remove them all including the training outputs
if is_training_mode:
to_remove = []
for name in node.output:
if name not in keep:
removed_output_names.add(name)
to_remove.append(name)

for name in to_remove:
node.output.remove(name)

if removed_output_names:
# Clean up corresponding value_info entries
keep = [vi for vi in onnx_model.graph.value_info if vi.name not in removed_output_names]
del onnx_model.graph.value_info[:]
onnx_model.graph.value_info.extend(keep)

return onnx_model
23 changes: 1 addition & 22 deletions modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
get_node_names,
get_output_names,
get_output_shapes,
remove_node_training_mode,
)
from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
from modelopt.torch.utils import flatten_tree, standardize_named_model_args
Expand Down Expand Up @@ -569,25 +570,3 @@ def get_onnx_bytes(*args, **kwargs) -> bytes:
onnx_bytes = get_onnx_bytes_and_metadata(*args, **kwargs)[0]
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
return onnx_bytes_obj.get_onnx_model_file_bytes()


def remove_node_training_mode(onnx_model: ModelProto, node_op_type: str) -> ModelProto:
"""Remove training_mode attribute from selected node type.

Args:
onnx_model: The onnx model.
node_op_type: The node type to remove training_mode attribute from.

Returns:
The onnx model with the training_mode attribute removed.
"""
for node in onnx_model.graph.node:
if node.op_type == node_op_type:
for attribute in node.attribute:
if attribute.name == "training_mode":
if attribute.i == 1:
node.output.remove(node.output[1])
node.output.remove(node.output[1])
attribute.i = 0

return onnx_model
224 changes: 223 additions & 1 deletion tests/unit/onnx/test_onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,29 @@

import os

import numpy as np
import onnx
import pytest
from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
from onnx.helper import (
make_graph,
make_model,
make_node,
make_opsetid,
make_tensor,
make_tensor_value_info,
)

from modelopt.onnx.utils import save_onnx_bytes_to_dir, validate_onnx
from modelopt.onnx.utils import (
get_input_names_from_bytes,
get_output_names_from_bytes,
randomize_weights_onnx_bytes,
remove_node_training_mode,
remove_weights_data,
save_onnx_bytes_to_dir,
validate_onnx,
)
from modelopt.torch._deploy.utils import get_onnx_bytes


@pytest.mark.parametrize(
Expand All @@ -31,3 +51,205 @@ def test_validate_onnx(onnx_bytes):
def test_save_onnx(tmp_path):
save_onnx_bytes_to_dir(b"test_onnx_bytes", tmp_path, "test")
assert os.path.exists(os.path.join(tmp_path, "test.onnx"))


def make_onnx_model_for_matmul_op():
input_left = np.array([1, 2])
input_right = np.array([1, 3])
output_shape = np.matmul(input_left, input_right).shape
node = make_node("MatMul", ["X", "Y"], ["Z"], name="matmul")
graph = make_graph(
[node],
"test_graph",
[
make_tensor_value_info("X", onnx.TensorProto.FLOAT, input_left.shape),
make_tensor_value_info("Y", onnx.TensorProto.FLOAT, input_right.shape),
],
[make_tensor_value_info("Z", onnx.TensorProto.FLOAT, output_shape)],
)
model = make_model(graph, producer_name="Omniengine Tester")
return model.SerializeToString()


def test_input_names():
model_bytes = make_onnx_model_for_matmul_op()
input_names = get_input_names_from_bytes(model_bytes)
assert input_names == ["X", "Y"]


def test_output_names():
model_bytes = make_onnx_model_for_matmul_op()
output_names = get_output_names_from_bytes(model_bytes)
assert output_names == ["Z"]


def _get_avg_var_of_weights(model):
inits = model.graph.initializer
avg_var_dict = {}

for init in inits:
if len(init.dims) > 1:
dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type)
if dtype in ["float16", "float32", "float64"]:
np_tensor = np.frombuffer(init.raw_data, dtype=dtype)
avg_var_dict[init.name + "_avg"] = np.average(np_tensor)
avg_var_dict[init.name + "_var"] = np.var(np_tensor)

return avg_var_dict


def test_random_onnx_weights():
model, args, kwargs = get_tiny_resnet_and_input()
assert not kwargs

onnx_bytes = get_onnx_bytes(model, args)
original_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
original_model_size = len(onnx_bytes)

onnx_bytes = remove_weights_data(onnx_bytes)
# Removed model weights should be greater than 18 MB
assert original_model_size - len(onnx_bytes) > 18e6

# After assigning random weights, model size should be slightly greater than the the original
# size due to some extra metadata
onnx_bytes = randomize_weights_onnx_bytes(onnx_bytes)
assert len(onnx_bytes) > original_model_size

randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
for key, value in original_avg_var_dict.items():
assert abs(value - randomized_avg_var_dict[key]) < 0.1


def test_reproducible_random_weights():
model, args, kwargs = get_tiny_resnet_and_input()
assert not kwargs

original_onnx_bytes = get_onnx_bytes(model, args)
onnx_bytes_wo_weights = remove_weights_data(original_onnx_bytes)

# Check if the randomization produces the same weights
onnx_bytes_1 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
onnx_bytes_2 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
assert onnx_bytes_1 == onnx_bytes_2


def _make_bn_initializer(name: str, shape, value=1.0):
"""Helper to create an initializer tensor for BatchNorm."""
data = np.full(shape, value, dtype=np.float32)
return make_tensor(name, onnx.TensorProto.FLOAT, shape, data.flatten())


def _make_batchnorm_model(bn_node, extra_value_infos=None):
"""Helper to create an ONNX model with a BatchNormalization node.

The created model has the following schematic structure:

graph name: "test_graph"
inputs:
- input: FLOAT [1, 3, 224, 224]
initializers:
- scale: FLOAT [3]
- bias: FLOAT [3]
- mean: FLOAT [3]
- var: FLOAT [3]
nodes:
- BatchNormalization (name comes from `bn_node`), with:
inputs = ["input", "scale", "bias", "mean", "var"]
outputs = as provided by `bn_node` (e.g., ["output"], or
["output", "running_mean", "running_var", "saved_mean"])
outputs:
- output: FLOAT [1, 3, 224, 224]

If `extra_value_infos` is provided (e.g., value_info for non-training outputs
like "running_mean"/"running_var" and/or training-only outputs like
"saved_mean"/"saved_inv_std"), they are attached to the graph's value_info.
Some tests subsequently invoke utilities (e.g., remove_node_training_mode)
that prune training-only outputs and their value_info entries, while keeping
regular outputs such as "running_mean" and "running_var" intact.
"""
initializers = [
_make_bn_initializer("scale", [3], 1.0),
_make_bn_initializer("bias", [3], 0.0),
_make_bn_initializer("mean", [3], 0.0),
_make_bn_initializer("var", [3], 1.0),
]

graph_outputs = []
for output_name, shape in [
("output", [1, 3, 224, 224]),
("running_mean", [3]),
("running_var", [3]),
]:
if output_name in bn_node.output:
graph_outputs.append(make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape))

graph_def = make_graph(
[bn_node],
"test_graph",
[make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])],
graph_outputs,
initializer=initializers,
value_info=extra_value_infos or [],
)

return make_model(graph_def, opset_imports=[make_opsetid("", 14)])


def test_remove_node_training_mode_attribute():
"""Test removal of training_mode attribute from BatchNormalization nodes."""
bn_node = make_node(
"BatchNormalization",
inputs=["input", "scale", "bias", "mean", "var"],
outputs=["output"],
name="bn1",
training_mode=1, # This attribute should be removed
)

model = _make_batchnorm_model(bn_node)
result_model = remove_node_training_mode(model, "BatchNormalization")

bn_node_result = result_model.graph.node[0]
assert bn_node_result.op_type == "BatchNormalization"

# Check that training_mode attribute is not present
attr_names = [attr.name for attr in bn_node_result.attribute]
assert "training_mode" not in attr_names


def test_remove_node_extra_training_outputs():
"""Test removal of extra training outputs from BatchNormalization nodes."""
bn_node = make_node(
"BatchNormalization",
inputs=["input", "scale", "bias", "mean", "var"],
outputs=[
"output",
"running_mean",
"running_var",
"saved_mean",
"saved_inv_std",
],
name="bn1",
training_mode=1,
)

# Extra training outputs are attached to the graph's value_info
value_infos = [
make_tensor_value_info("saved_mean", onnx.TensorProto.FLOAT, [3]),
make_tensor_value_info("saved_inv_std", onnx.TensorProto.FLOAT, [3]),
]

model = _make_batchnorm_model(bn_node, extra_value_infos=value_infos)
result_model = remove_node_training_mode(model, "BatchNormalization")

# Verify only the non-training outputs remain
bn_node_result = result_model.graph.node[0]
print(bn_node_result.output)
assert len(bn_node_result.output) == 3
assert bn_node_result.output[0] == "output"
assert bn_node_result.output[1] == "running_mean"
assert bn_node_result.output[2] == "running_var"

# Verify value_info entries for removed outputs are cleaned up
value_info_names = [vi.name for vi in result_model.graph.value_info]
assert "saved_mean" not in value_info_names
assert "saved_inv_std" not in value_info_names
Loading
Loading