Skip to content

Commit 04b8496

Browse files
committed
Traning mode removal from ONNX nodes
Signed-off-by: Riyad Islam <[email protected]>
1 parent c391942 commit 04b8496

File tree

3 files changed

+132
-26
lines changed

3 files changed

+132
-26
lines changed

modelopt/onnx/utils.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import numpy as np
2626
import onnx
2727
import onnx_graphsurgeon as gs
28-
from onnx import TensorProto, ValueInfoProto, numpy_helper
2928
from onnx.helper import get_attribute_value
3029
from onnx_graphsurgeon import Constant, Node, Variable
3130

@@ -289,7 +288,7 @@ def _convert_types_to_np(types: dict[str, int] | list[int] | int) -> Any:
289288

290289
def get_tensor_by_name(
291290
onnx_model: onnx.ModelProto, tensor_name: str
292-
) -> ValueInfoProto | TensorProto | None:
291+
) -> onnx.ValueInfoProto | onnx.TensorProto | None:
293292
"""This function returns a tensor from its name.
294293
295294
This function searches for a tensor in the model's:
@@ -438,7 +437,7 @@ def randomize_weights_onnx_bytes(onnx_bytes: bytes, seed: int = 0) -> bytes:
438437
numpy_array = np.random.normal(float(avg), float(var), size=init.dims).astype(
439438
dtype
440439
)
441-
tensor = numpy_helper.from_array(numpy_array, init.name)
440+
tensor = onnx.numpy_helper.from_array(numpy_array, init.name)
442441
model.graph.initializer[idx].CopyFrom(tensor)
443442

444443
buffer = io.BytesIO()
@@ -751,3 +750,46 @@ def onnx_type_str_to_enum(dtype: str) -> int:
751750
dtype = dtype.split("tensor(")[-1].split(")")[0]
752751
dtype = "FLOAT" if dtype == "float32" else dtype.upper()
753752
return getattr(onnx.TensorProto, dtype)
753+
754+
755+
def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
756+
"""Remove `training_mode` attribute and extra training outputs from nodes of a given op type.
757+
758+
Args:
759+
onnx_model: The onnx model.
760+
node_op_type: The node type to remove training_mode attribute from.
761+
762+
Returns:
763+
The onnx model with the training_mode attribute removed.
764+
"""
765+
removed_output_names = set()
766+
767+
for node in onnx_model.graph.node:
768+
if node.op_type != node_op_type:
769+
continue
770+
771+
# Drop the 'training_mode' attribute if present
772+
for idx, attr in enumerate(list(node.attribute)):
773+
if attr.name == "training_mode":
774+
del node.attribute[idx]
775+
break
776+
777+
# If node has extra training outputs, keep only the first
778+
if len(node.output) > 1:
779+
removed_output_names.update(node.output[1:])
780+
node.output[:] = node.output[:1]
781+
782+
if removed_output_names:
783+
# Clean up corresponding value_info entries
784+
keep = [vi for vi in onnx_model.graph.value_info if vi.name not in removed_output_names]
785+
del onnx_model.graph.value_info[:]
786+
onnx_model.graph.value_info.extend(keep)
787+
788+
# Also clean up graph.output entries
789+
keep_outputs = [
790+
out for out in onnx_model.graph.output if out.name not in removed_output_names
791+
]
792+
del onnx_model.graph.output[:]
793+
onnx_model.graph.output.extend(keep_outputs)
794+
795+
return onnx_model

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
get_node_names,
4646
get_output_names,
4747
get_output_shapes,
48+
remove_node_training_mode,
4849
)
4950
from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
5051
from modelopt.torch.utils import flatten_tree, standardize_named_model_args
@@ -569,25 +570,3 @@ def get_onnx_bytes(*args, **kwargs) -> bytes:
569570
onnx_bytes = get_onnx_bytes_and_metadata(*args, **kwargs)[0]
570571
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
571572
return onnx_bytes_obj.get_onnx_model_file_bytes()
572-
573-
574-
def remove_node_training_mode(onnx_model: ModelProto, node_op_type: str) -> ModelProto:
575-
"""Remove training_mode attribute from selected node type.
576-
577-
Args:
578-
onnx_model: The onnx model.
579-
node_op_type: The node type to remove training_mode attribute from.
580-
581-
Returns:
582-
The onnx model with the training_mode attribute removed.
583-
"""
584-
for node in onnx_model.graph.node:
585-
if node.op_type == node_op_type:
586-
for attribute in node.attribute:
587-
if attribute.name == "training_mode":
588-
if attribute.i == 1:
589-
node.output.remove(node.output[1])
590-
node.output.remove(node.output[1])
591-
attribute.i = 0
592-
593-
return onnx_model

tests/unit/torch/deploy/utils/test_torch_onnx_utils.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,21 @@
2424
import torch.nn as nn
2525
from _test_utils.torch_model.deploy_models import BaseDeployModel, get_deploy_models
2626
from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
27-
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
27+
from onnx.helper import (
28+
make_graph,
29+
make_model,
30+
make_node,
31+
make_opsetid,
32+
make_tensor,
33+
make_tensor_value_info,
34+
)
2835

2936
from modelopt.onnx.utils import (
3037
get_batch_size_from_bytes,
3138
get_input_names_from_bytes,
3239
get_output_names_from_bytes,
3340
randomize_weights_onnx_bytes,
41+
remove_node_training_mode,
3442
remove_weights_data,
3543
validate_batch_size,
3644
)
@@ -255,3 +263,80 @@ def test_reproducible_random_weights():
255263
onnx_bytes_1 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
256264
onnx_bytes_2 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
257265
assert onnx_bytes_1 == onnx_bytes_2
266+
267+
268+
def _make_bn_initializer(name: str, shape, value=1.0):
269+
"""Helper to create an initializer tensor for BatchNorm."""
270+
data = np.full(shape, value, dtype=np.float32)
271+
return make_tensor(name, onnx.TensorProto.FLOAT, shape, data.flatten())
272+
273+
274+
def _make_batchnorm_model(bn_node, extra_value_infos=None):
275+
"""Helper to create an ONNX model with a BatchNormalization node."""
276+
initializers = [
277+
_make_bn_initializer("scale", [3], 1.0),
278+
_make_bn_initializer("bias", [3], 0.0),
279+
_make_bn_initializer("mean", [3], 0.0),
280+
_make_bn_initializer("var", [3], 1.0),
281+
]
282+
283+
graph_def = make_graph(
284+
[bn_node],
285+
"test_graph",
286+
[make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])],
287+
[make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 3, 224, 224])],
288+
initializer=initializers,
289+
value_info=extra_value_infos or [],
290+
)
291+
292+
return make_model(graph_def, opset_imports=[make_opsetid("", 14)])
293+
294+
295+
def test_remove_node_training_mode_attribute():
296+
"""Test removal of training_mode attribute from BatchNormalization nodes."""
297+
bn_node = make_node(
298+
"BatchNormalization",
299+
inputs=["input", "scale", "bias", "mean", "var"],
300+
outputs=["output"],
301+
name="bn1",
302+
training_mode=1, # This attribute should be removed
303+
)
304+
305+
model = _make_batchnorm_model(bn_node)
306+
result_model = remove_node_training_mode(model, "BatchNormalization")
307+
308+
bn_node_result = result_model.graph.node[0]
309+
assert bn_node_result.op_type == "BatchNormalization"
310+
311+
# Check that training_mode attribute is not present
312+
attr_names = [attr.name for attr in bn_node_result.attribute]
313+
assert "training_mode" not in attr_names
314+
315+
316+
def test_remove_node_extra_training_outputs():
317+
"""Test removal of extra training outputs from BatchNormalization nodes."""
318+
bn_node = make_node(
319+
"BatchNormalization",
320+
inputs=["input", "scale", "bias", "mean", "var"],
321+
outputs=["output", "saved_mean", "saved_inv_std"], # Extra training outputs
322+
name="bn1",
323+
training_mode=1,
324+
)
325+
326+
value_infos = [
327+
make_tensor_value_info("saved_mean", onnx.TensorProto.FLOAT, [3]),
328+
make_tensor_value_info("saved_inv_std", onnx.TensorProto.FLOAT, [3]),
329+
]
330+
331+
model = _make_batchnorm_model(bn_node, extra_value_infos=value_infos)
332+
result_model = remove_node_training_mode(model, "BatchNormalization")
333+
334+
# Verify only first output remains
335+
bn_node_result = result_model.graph.node[0]
336+
assert len(bn_node_result.output) == 1
337+
assert bn_node_result.output[0] == "output"
338+
339+
# Verify value_info entries for removed outputs are cleaned up
340+
value_info_names = [vi.name for vi in result_model.graph.value_info]
341+
assert "saved_mean" not in value_info_names
342+
assert "saved_inv_std" not in value_info_names

0 commit comments

Comments
 (0)