Commit d0372f4
Training mode removal from ONNX nodes (#277)
Signed-off-by: Riyad Islam <[email protected]>
Co-authored-by: Keval Morabia <[email protected]>
1 parent 6ec8cdc commit d0372f4

File tree

4 files changed: +278, -117 lines changed

modelopt/onnx/utils.py

Lines changed: 52 additions & 3 deletions
@@ -25,7 +25,6 @@
 import numpy as np
 import onnx
 import onnx_graphsurgeon as gs
-from onnx import TensorProto, ValueInfoProto, numpy_helper
 from onnx.helper import get_attribute_value
 from onnx_graphsurgeon import Constant, Node, Variable

@@ -289,7 +288,7 @@ def _convert_types_to_np(types: dict[str, int] | list[int] | int) -> Any:

 def get_tensor_by_name(
     onnx_model: onnx.ModelProto, tensor_name: str
-) -> ValueInfoProto | TensorProto | None:
+) -> onnx.ValueInfoProto | onnx.TensorProto | None:
     """This function returns a tensor from its name.

     This function searches for a tensor in the model's:
@@ -438,7 +437,7 @@ def randomize_weights_onnx_bytes(onnx_bytes: bytes, seed: int = 0) -> bytes:
             numpy_array = np.random.normal(float(avg), float(var), size=init.dims).astype(
                 dtype
             )
-            tensor = numpy_helper.from_array(numpy_array, init.name)
+            tensor = onnx.numpy_helper.from_array(numpy_array, init.name)
             model.graph.initializer[idx].CopyFrom(tensor)

     buffer = io.BytesIO()
@@ -751,3 +750,53 @@ def onnx_type_str_to_enum(dtype: str) -> int:
     dtype = dtype.split("tensor(")[-1].split(")")[0]
     dtype = "FLOAT" if dtype == "float32" else dtype.upper()
     return getattr(onnx.TensorProto, dtype)
+
+
+def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
+    """Remove `training_mode` attribute and extra training outputs from nodes of a given op type.
+
+    This also removes the unused outputs from the training_mode nodes.
+
+    Args:
+        onnx_model: The onnx model.
+        node_op_type: The node type to remove training_mode attribute from.
+
+    Returns:
+        The onnx model with the training_mode attribute removed.
+    """
+    removed_output_names = set()
+    all_inputs = {inp for n in onnx_model.graph.node for inp in n.input}
+    graph_outputs = {o.name for o in onnx_model.graph.output}
+    keep = all_inputs | graph_outputs
+
+    for node in onnx_model.graph.node:
+        if node.op_type != node_op_type:
+            continue
+
+        is_training_mode = False
+        # Drop the 'training_mode' attribute if present
+        for idx, attr in enumerate(list(node.attribute)):
+            if attr.name == "training_mode":
+                del node.attribute[idx]
+                if attr.i == 1:
+                    is_training_mode = True
+                break
+
+        # If the node has extra outputs, remove them all including the training outputs
+        if is_training_mode:
+            to_remove = []
+            for name in node.output:
+                if name not in keep:
+                    removed_output_names.add(name)
+                    to_remove.append(name)
+
+            for name in to_remove:
+                node.output.remove(name)
+
+    if removed_output_names:
+        # Clean up corresponding value_info entries
+        keep = [vi for vi in onnx_model.graph.value_info if vi.name not in removed_output_names]
+        del onnx_model.graph.value_info[:]
+        onnx_model.graph.value_info.extend(keep)
+
+    return onnx_model
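For context, a minimal usage sketch of the new helper (not part of this commit; the model paths below are illustrative placeholders, and "BatchNormalization" is the op type exercised by the new tests):

import onnx

from modelopt.onnx.utils import remove_node_training_mode

# Hypothetical cleanup pass: load an exported model, strip training-mode leftovers
# from its BatchNormalization nodes, and save the result.
# "model.onnx" / "model_inference.onnx" are placeholder paths, not files from this commit.
model = onnx.load("model.onnx")
model = remove_node_training_mode(model, "BatchNormalization")
onnx.save(model, "model_inference.onnx")

Note the design of the new version: outputs that are still consumed by downstream nodes or exposed as graph outputs (e.g., running_mean/running_var) are kept, while dangling training-only outputs such as saved_mean/saved_inv_std are dropped along with their value_info entries.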

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 1 addition & 22 deletions
@@ -45,6 +45,7 @@
     get_node_names,
     get_output_names,
     get_output_shapes,
+    remove_node_training_mode,
 )
 from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
 from modelopt.torch.utils import flatten_tree, standardize_named_model_args
@@ -569,25 +570,3 @@ def get_onnx_bytes(*args, **kwargs) -> bytes:
     onnx_bytes = get_onnx_bytes_and_metadata(*args, **kwargs)[0]
     onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
     return onnx_bytes_obj.get_onnx_model_file_bytes()
-
-
-def remove_node_training_mode(onnx_model: ModelProto, node_op_type: str) -> ModelProto:
-    """Remove training_mode attribute from selected node type.
-
-    Args:
-        onnx_model: The onnx model.
-        node_op_type: The node type to remove training_mode attribute from.
-
-    Returns:
-        The onnx model with the training_mode attribute removed.
-    """
-    for node in onnx_model.graph.node:
-        if node.op_type == node_op_type:
-            for attribute in node.attribute:
-                if attribute.name == "training_mode":
-                    if attribute.i == 1:
-                        node.output.remove(node.output[1])
-                        node.output.remove(node.output[1])
-                    attribute.i = 0
-
-    return onnx_model

tests/unit/onnx/test_onnx_utils.py

Lines changed: 223 additions & 1 deletion
@@ -15,9 +15,29 @@

 import os

+import numpy as np
+import onnx
 import pytest
+from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
+from onnx.helper import (
+    make_graph,
+    make_model,
+    make_node,
+    make_opsetid,
+    make_tensor,
+    make_tensor_value_info,
+)

-from modelopt.onnx.utils import save_onnx_bytes_to_dir, validate_onnx
+from modelopt.onnx.utils import (
+    get_input_names_from_bytes,
+    get_output_names_from_bytes,
+    randomize_weights_onnx_bytes,
+    remove_node_training_mode,
+    remove_weights_data,
+    save_onnx_bytes_to_dir,
+    validate_onnx,
+)
+from modelopt.torch._deploy.utils import get_onnx_bytes


 @pytest.mark.parametrize(
@@ -31,3 +51,205 @@ def test_validate_onnx(onnx_bytes):
 def test_save_onnx(tmp_path):
     save_onnx_bytes_to_dir(b"test_onnx_bytes", tmp_path, "test")
     assert os.path.exists(os.path.join(tmp_path, "test.onnx"))
+
+
+def make_onnx_model_for_matmul_op():
+    input_left = np.array([1, 2])
+    input_right = np.array([1, 3])
+    output_shape = np.matmul(input_left, input_right).shape
+    node = make_node("MatMul", ["X", "Y"], ["Z"], name="matmul")
+    graph = make_graph(
+        [node],
+        "test_graph",
+        [
+            make_tensor_value_info("X", onnx.TensorProto.FLOAT, input_left.shape),
+            make_tensor_value_info("Y", onnx.TensorProto.FLOAT, input_right.shape),
+        ],
+        [make_tensor_value_info("Z", onnx.TensorProto.FLOAT, output_shape)],
+    )
+    model = make_model(graph, producer_name="Omniengine Tester")
+    return model.SerializeToString()
+
+
+def test_input_names():
+    model_bytes = make_onnx_model_for_matmul_op()
+    input_names = get_input_names_from_bytes(model_bytes)
+    assert input_names == ["X", "Y"]
+
+
+def test_output_names():
+    model_bytes = make_onnx_model_for_matmul_op()
+    output_names = get_output_names_from_bytes(model_bytes)
+    assert output_names == ["Z"]
+
+
+def _get_avg_var_of_weights(model):
+    inits = model.graph.initializer
+    avg_var_dict = {}
+
+    for init in inits:
+        if len(init.dims) > 1:
+            dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type)
+            if dtype in ["float16", "float32", "float64"]:
+                np_tensor = np.frombuffer(init.raw_data, dtype=dtype)
+                avg_var_dict[init.name + "_avg"] = np.average(np_tensor)
+                avg_var_dict[init.name + "_var"] = np.var(np_tensor)
+
+    return avg_var_dict
+
+
+def test_random_onnx_weights():
+    model, args, kwargs = get_tiny_resnet_and_input()
+    assert not kwargs
+
+    onnx_bytes = get_onnx_bytes(model, args)
+    original_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
+    original_model_size = len(onnx_bytes)
+
+    onnx_bytes = remove_weights_data(onnx_bytes)
+    # Removed model weights should be greater than 18 MB
+    assert original_model_size - len(onnx_bytes) > 18e6
+
+    # After assigning random weights, model size should be slightly greater than the original
+    # size due to some extra metadata
+    onnx_bytes = randomize_weights_onnx_bytes(onnx_bytes)
+    assert len(onnx_bytes) > original_model_size
+
+    randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
+    for key, value in original_avg_var_dict.items():
+        assert abs(value - randomized_avg_var_dict[key]) < 0.1
+
+
+def test_reproducible_random_weights():
+    model, args, kwargs = get_tiny_resnet_and_input()
+    assert not kwargs
+
+    original_onnx_bytes = get_onnx_bytes(model, args)
+    onnx_bytes_wo_weights = remove_weights_data(original_onnx_bytes)
+
+    # Check if the randomization produces the same weights
+    onnx_bytes_1 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
+    onnx_bytes_2 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
+    assert onnx_bytes_1 == onnx_bytes_2
+
+
+def _make_bn_initializer(name: str, shape, value=1.0):
+    """Helper to create an initializer tensor for BatchNorm."""
+    data = np.full(shape, value, dtype=np.float32)
+    return make_tensor(name, onnx.TensorProto.FLOAT, shape, data.flatten())
+
+
+def _make_batchnorm_model(bn_node, extra_value_infos=None):
+    """Helper to create an ONNX model with a BatchNormalization node.
+
+    The created model has the following schematic structure:
+
+    graph name: "test_graph"
+    inputs:
+        - input: FLOAT [1, 3, 224, 224]
+    initializers:
+        - scale: FLOAT [3]
+        - bias: FLOAT [3]
+        - mean: FLOAT [3]
+        - var: FLOAT [3]
+    nodes:
+        - BatchNormalization (name comes from `bn_node`), with:
+            inputs = ["input", "scale", "bias", "mean", "var"]
+            outputs = as provided by `bn_node` (e.g., ["output"], or
+                ["output", "running_mean", "running_var", "saved_mean"])
+    outputs:
+        - output: FLOAT [1, 3, 224, 224]
+
+    If `extra_value_infos` is provided (e.g., value_info for non-training outputs
+    like "running_mean"/"running_var" and/or training-only outputs like
+    "saved_mean"/"saved_inv_std"), they are attached to the graph's value_info.
+    Some tests subsequently invoke utilities (e.g., remove_node_training_mode)
+    that prune training-only outputs and their value_info entries, while keeping
+    regular outputs such as "running_mean" and "running_var" intact.
+    """
+    initializers = [
+        _make_bn_initializer("scale", [3], 1.0),
+        _make_bn_initializer("bias", [3], 0.0),
+        _make_bn_initializer("mean", [3], 0.0),
+        _make_bn_initializer("var", [3], 1.0),
+    ]
+
+    graph_outputs = []
+    for output_name, shape in [
+        ("output", [1, 3, 224, 224]),
+        ("running_mean", [3]),
+        ("running_var", [3]),
+    ]:
+        if output_name in bn_node.output:
+            graph_outputs.append(make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape))
+
+    graph_def = make_graph(
+        [bn_node],
+        "test_graph",
+        [make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])],
+        graph_outputs,
+        initializer=initializers,
+        value_info=extra_value_infos or [],
+    )
+
+    return make_model(graph_def, opset_imports=[make_opsetid("", 14)])
+
+
+def test_remove_node_training_mode_attribute():
+    """Test removal of training_mode attribute from BatchNormalization nodes."""
+    bn_node = make_node(
+        "BatchNormalization",
+        inputs=["input", "scale", "bias", "mean", "var"],
+        outputs=["output"],
+        name="bn1",
+        training_mode=1,  # This attribute should be removed
+    )
+
+    model = _make_batchnorm_model(bn_node)
+    result_model = remove_node_training_mode(model, "BatchNormalization")
+
+    bn_node_result = result_model.graph.node[0]
+    assert bn_node_result.op_type == "BatchNormalization"
+
+    # Check that training_mode attribute is not present
+    attr_names = [attr.name for attr in bn_node_result.attribute]
+    assert "training_mode" not in attr_names
+
+
+def test_remove_node_extra_training_outputs():
+    """Test removal of extra training outputs from BatchNormalization nodes."""
+    bn_node = make_node(
+        "BatchNormalization",
+        inputs=["input", "scale", "bias", "mean", "var"],
+        outputs=[
+            "output",
+            "running_mean",
+            "running_var",
+            "saved_mean",
+            "saved_inv_std",
+        ],
+        name="bn1",
+        training_mode=1,
+    )
+
+    # Extra training outputs are attached to the graph's value_info
+    value_infos = [
+        make_tensor_value_info("saved_mean", onnx.TensorProto.FLOAT, [3]),
+        make_tensor_value_info("saved_inv_std", onnx.TensorProto.FLOAT, [3]),
+    ]
+
+    model = _make_batchnorm_model(bn_node, extra_value_infos=value_infos)
+    result_model = remove_node_training_mode(model, "BatchNormalization")
+
+    # Verify only the non-training outputs remain
+    bn_node_result = result_model.graph.node[0]
+    print(bn_node_result.output)
+    assert len(bn_node_result.output) == 3
+    assert bn_node_result.output[0] == "output"
+    assert bn_node_result.output[1] == "running_mean"
+    assert bn_node_result.output[2] == "running_var"
+
+    # Verify value_info entries for removed outputs are cleaned up
+    value_info_names = [vi.name for vi in result_model.graph.value_info]
+    assert "saved_mean" not in value_info_names
+    assert "saved_inv_std" not in value_info_names
