
Commit b9aa094

Author: Chris Elion

onnx: export model constants as outputs, instead of in the initializer list (#4073)

* export model constants as outputs, instead of in the initializer list
* changelog

1 parent: 03e488b
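
For orientation, here is a minimal standalone sketch (not this commit's code path, which goes through tf2onnx; the constant name, value, and the [1, 1, 1, 1] shape are illustrative) of the two places a scalar constant can live in an ONNX model: the initializer list, where the old export put it, versus a graph output, where it lands after this change.

```python
import onnx
from onnx import TensorProto, helper

# A) Old behaviour: the constant lives only in graph.initializer, as a named
#    TensorProto with no producing node.
initializer_tensor = helper.make_tensor(
    name="version_number",  # illustrative name
    data_type=TensorProto.INT32,
    dims=[1, 1, 1, 1],
    vals=[2],
)

# B) New behaviour: the constant is produced by a node and listed in
#    graph.output, so an importer that walks the outputs will see it.
constant_node = helper.make_node(
    "Constant",
    inputs=[],
    outputs=["version_number"],
    value=helper.make_tensor("value", TensorProto.INT32, [1, 1, 1, 1], [2]),
)
graph = helper.make_graph(
    nodes=[constant_node],
    name="constants_as_outputs_sketch",
    inputs=[],
    outputs=[helper.make_tensor_value_info("version_number", TensorProto.INT32, [1, 1, 1, 1])],
)
onnx.checker.check_model(helper.make_model(graph))
```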

File tree: 2 files changed (+5 lines, -29 lines)

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ vector observations to be used simultaneously. (#3981) Thank you @shakenes !
 ### Bug Fixes
 - Fixed an issue where SAC would perform too many model updates when resuming from a
   checkpoint, and too few when using `buffer_init_steps`. (#4038)
+- Fixed a bug in the onnx export that would cause constants needed for inference to not be visible to some versions of
+  the Barracuda importer. (#4073)
 #### com.unity.ml-agents (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
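
As a hedged sanity check on an exported model (the path below is hypothetical, and the exact constant names come from MODEL_CONSTANTS in model_serialization.py, which this diff does not show), the constants should now appear among the graph outputs rather than only in the initializer list:

```python
import onnx

# Hypothetical path to a model exported by mlagents-learn.
model = onnx.load("results/my_run/MyBehavior.onnx")

initializer_names = {t.name for t in model.graph.initializer}
output_names = {o.name for o in model.graph.output}

# After this fix, the inference-time constants (e.g. names like "version_number"
# or "memory_size" are assumed here) are expected among the outputs, possibly
# with a ":0" suffix depending on how tf2onnx names them.
print("outputs:     ", sorted(output_names))
print("initializers:", sorted(initializer_names))
```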

ml-agents/mlagents/model_serialization.py

Lines changed: 3 additions & 29 deletions
@@ -4,7 +4,6 @@
 from distutils.version import LooseVersion
 
 try:
-    import onnx
     from tf2onnx.tfonnx import process_tf_graph, tf_optimize
     from tf2onnx import optimizer
 
@@ -126,16 +125,6 @@ def convert_frozen_to_onnx(
 ) -> Any:
     # This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py
 
-    # Some constants in the graph need to be read by the inference system.
-    # These aren't used by the model anywhere, so trying to make sure they propagate
-    # through conversion and import is a losing battle. Instead, save them now,
-    # so that we can add them back later.
-    constant_values = {}
-    for n in frozen_graph_def.node:
-        if n.name in MODEL_CONSTANTS:
-            val = n.attr["value"].tensor.int_val[0]
-            constant_values[n.name] = val
-
     inputs = _get_input_node_names(frozen_graph_def)
     outputs = _get_output_node_names(frozen_graph_def)
     logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")
@@ -157,26 +146,9 @@ def convert_frozen_to_onnx(
     onnx_graph = optimizer.optimize_graph(g)
     model_proto = onnx_graph.make_model(settings.brain_name)
 
-    # Save the constant values back the graph initializer.
-    # This will ensure the importer gets them as global constants.
-    constant_nodes = []
-    for k, v in constant_values.items():
-        constant_node = _make_onnx_node_for_constant(k, v)
-        constant_nodes.append(constant_node)
-    model_proto.graph.initializer.extend(constant_nodes)
     return model_proto
 
 
-def _make_onnx_node_for_constant(name: str, value: int) -> Any:
-    tensor_value = onnx.TensorProto(
-        data_type=onnx.TensorProto.INT32,
-        name=name,
-        int32_data=[value],
-        dims=[1, 1, 1, 1],
-    )
-    return tensor_value
-
-
 def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
     """
     Get the list of input node names from the graph.
@@ -201,10 +173,12 @@ def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
 def _get_output_node_names(frozen_graph_def: Any) -> List[str]:
     """
     Get the list of output node names from the graph.
+    Also include constants, so that they will be readable by the
+    onnx importer.
     Names are suffixed with ":0"
     """
     node_names = _get_frozen_graph_node_names(frozen_graph_def)
-    output_names = node_names & POSSIBLE_OUTPUT_NODES
+    output_names = node_names & (POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS)
     # Append the port
     return [f"{n}:0" for n in output_names]
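
To make the `_get_output_node_names` change concrete, here is a self-contained sketch of the new selection logic; the contents of POSSIBLE_OUTPUT_NODES and MODEL_CONSTANTS below are invented stand-ins for the real frozensets defined elsewhere in model_serialization.py:

```python
from typing import FrozenSet, List, Set

# Invented stand-ins for the module-level frozensets in model_serialization.py.
POSSIBLE_OUTPUT_NODES: FrozenSet[str] = frozenset(["action"])
MODEL_CONSTANTS: FrozenSet[str] = frozenset(["memory_size", "version_number"])


def get_output_node_names(node_names: Set[str]) -> List[str]:
    # Keep the real outputs and the model constants, so the constants are
    # exported as graph outputs and stay visible to the importer.
    output_names = node_names & (POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS)
    # Append the port, matching the ":0" suffix used by the exporter.
    return [f"{n}:0" for n in output_names]


# A frozen graph with an action node, both constants, and an unrelated weight
# node: only the first three survive the intersection.
print(sorted(get_output_node_names({"action", "memory_size", "version_number", "dense/kernel"})))
```

Because the constants now ride along as ordinary graph outputs, the separate initializer round-trip (`constant_values`, `_make_onnx_node_for_constant`, and the `import onnx` it needed) is removed in the rest of the diff.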

Comments (0)