Skip to content

Commit 626d99d

Browse files
committed
Fix
Signed-off-by: Riyad Islam <[email protected]>
1 parent 8675622 commit 626d99d

File tree

2 files changed

+35
-15
lines changed

2 files changed

+35
-15
lines changed

modelopt/onnx/utils.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,8 @@ def onnx_type_str_to_enum(dtype: str) -> int:
755755
def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
756756
"""Remove `training_mode` attribute and extra training outputs from nodes of a given op type.
757757
758+
This also removes the unused outputs from the training_mode nodes.
759+
758760
Args:
759761
onnx_model: The onnx model.
760762
node_op_type: The node type to remove training_mode attribute from.
@@ -763,33 +765,38 @@ def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) ->
763765
The onnx model with the training_mode attribute removed.
764766
"""
765767
removed_output_names = set()
768+
all_inputs = {inp for n in onnx_model.graph.node for inp in n.input}
769+
graph_outputs = {o.name for o in onnx_model.graph.output}
770+
keep = all_inputs | graph_outputs
766771

767772
for node in onnx_model.graph.node:
768773
if node.op_type != node_op_type:
769774
continue
770775

776+
is_training_mode = False
771777
# Drop the 'training_mode' attribute if present
772778
for idx, attr in enumerate(list(node.attribute)):
773779
if attr.name == "training_mode":
774780
del node.attribute[idx]
781+
if attr.i == 1:
782+
is_training_mode = True
775783
break
776784

777-
# If node has extra training outputs, keep only the first
778-
if len(node.output) > 1:
779-
removed_output_names.update(node.output[1:])
780-
node.output[:] = node.output[:1]
785+
# If the node was in training mode, remove all of its extra outputs, including the training-only outputs
786+
if is_training_mode:
787+
to_remove = []
788+
for name in node.output:
789+
if name not in keep:
790+
removed_output_names.add(name)
791+
to_remove.append(name)
792+
793+
for name in to_remove:
794+
node.output.remove(name)
781795

782796
if removed_output_names:
783797
# Clean up corresponding value_info entries
784798
keep = [vi for vi in onnx_model.graph.value_info if vi.name not in removed_output_names]
785799
del onnx_model.graph.value_info[:]
786800
onnx_model.graph.value_info.extend(keep)
787801

788-
# Also clean up graph.output entries
789-
keep_outputs = [
790-
out for out in onnx_model.graph.output if out.name not in removed_output_names
791-
]
792-
del onnx_model.graph.output[:]
793-
onnx_model.graph.output.extend(keep_outputs)
794-
795802
return onnx_model

tests/unit/torch/deploy/utils/test_torch_onnx_utils.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,20 @@ def _make_batchnorm_model(bn_node, extra_value_infos=None):
306306
_make_bn_initializer("var", [3], 1.0),
307307
]
308308

309+
graph_outputs = []
310+
for output_name, shape in [
311+
("output", [1, 3, 224, 224]),
312+
("running_mean", [3]),
313+
("running_var", [3]),
314+
]:
315+
if output_name in bn_node.output:
316+
graph_outputs.append(make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape))
317+
309318
graph_def = make_graph(
310319
[bn_node],
311320
"test_graph",
312321
[make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])],
313-
[make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 3, 224, 224])],
322+
graph_outputs,
314323
initializer=initializers,
315324
value_info=extra_value_infos or [],
316325
)
@@ -350,11 +359,12 @@ def test_remove_node_extra_training_outputs():
350359
"running_var",
351360
"saved_mean",
352361
"saved_inv_std",
353-
], # Extra training outputs
362+
],
354363
name="bn1",
355364
training_mode=1,
356365
)
357366

367+
# The extra training outputs are registered in the graph's value_info
358368
value_infos = [
359369
make_tensor_value_info("saved_mean", onnx.TensorProto.FLOAT, [3]),
360370
make_tensor_value_info("saved_inv_std", onnx.TensorProto.FLOAT, [3]),
@@ -363,10 +373,13 @@ def test_remove_node_extra_training_outputs():
363373
model = _make_batchnorm_model(bn_node, extra_value_infos=value_infos)
364374
result_model = remove_node_training_mode(model, "BatchNormalization")
365375

366-
# Verify only first output remains
376+
# Verify only the non-training outputs remain
367377
bn_node_result = result_model.graph.node[0]
368-
assert len(bn_node_result.output) == 1
378+
print(bn_node_result.output)
379+
assert len(bn_node_result.output) == 3
369380
assert bn_node_result.output[0] == "output"
381+
assert bn_node_result.output[2] == "running_var"
382+
assert bn_node_result.output[1] == "running_mean"
370383

371384
# Verify value_info entries for removed outputs are cleaned up
372385
value_info_names = [vi.name for vi in result_model.graph.value_info]

0 commit comments

Comments
 (0)