Skip to content

Commit 8d614b5

Browse files
committed
Doc and test update
Signed-off-by: Riyad Islam <[email protected]>
1 parent 04b8496 commit 8d614b5

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

tests/unit/torch/deploy/utils/test_torch_onnx_utils.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,33 @@ def _make_bn_initializer(name: str, shape, value=1.0):
272272

273273

274274
def _make_batchnorm_model(bn_node, extra_value_infos=None):
275-
"""Helper to create an ONNX model with a BatchNormalization node."""
275+
"""Helper to create an ONNX model with a BatchNormalization node.
276+
277+
The created model has the following schematic structure:
278+
279+
graph name: "test_graph"
280+
inputs:
281+
- input: FLOAT [1, 3, 224, 224]
282+
initializers:
283+
- scale: FLOAT [3]
284+
- bias: FLOAT [3]
285+
- mean: FLOAT [3]
286+
- var: FLOAT [3]
287+
nodes:
288+
- BatchNormalization (name comes from `bn_node`), with:
289+
inputs = ["input", "scale", "bias", "mean", "var"]
290+
outputs = as provided by `bn_node` (e.g., ["output"], or
291+
["output", "running_mean", "running_var", "saved_mean"])
292+
outputs:
293+
- output: FLOAT [1, 3, 224, 224]
294+
295+
If `extra_value_infos` is provided (e.g., value_info for non-training outputs
296+
like "running_mean"/"running_var" and/or training-only outputs like
297+
"saved_mean"/"saved_inv_std"), they are attached to the graph's value_info.
298+
Some tests subsequently invoke utilities (e.g., remove_node_training_mode)
299+
that prune training-only outputs and their value_info entries, while keeping
300+
regular outputs such as "running_mean" and "running_var" intact.
301+
"""
276302
initializers = [
277303
_make_bn_initializer("scale", [3], 1.0),
278304
_make_bn_initializer("bias", [3], 0.0),
@@ -318,7 +344,13 @@ def test_remove_node_extra_training_outputs():
318344
bn_node = make_node(
319345
"BatchNormalization",
320346
inputs=["input", "scale", "bias", "mean", "var"],
321-
outputs=["output", "saved_mean", "saved_inv_std"], # Extra training outputs
347+
outputs=[
348+
"output",
349+
"running_mean",
350+
"running_var",
351+
"saved_mean",
352+
"saved_inv_std",
353+
], # Extra training outputs
322354
name="bn1",
323355
training_mode=1,
324356
)

0 commit comments

Comments
 (0)