@@ -272,7 +272,33 @@ def _make_bn_initializer(name: str, shape, value=1.0):
272272
273273
274274def _make_batchnorm_model (bn_node , extra_value_infos = None ):
275- """Helper to create an ONNX model with a BatchNormalization node."""
275+ """Helper to create an ONNX model with a BatchNormalization node.
276+
277+ The created model has the following schematic structure:
278+
279+ graph name: "test_graph"
280+ inputs:
281+ - input: FLOAT [1, 3, 224, 224]
282+ initializers:
283+ - scale: FLOAT [3]
284+ - bias: FLOAT [3]
285+ - mean: FLOAT [3]
286+ - var: FLOAT [3]
287+ nodes:
288+ - BatchNormalization (name comes from `bn_node`), with:
289+ inputs = ["input", "scale", "bias", "mean", "var"]
290+ outputs = as provided by `bn_node` (e.g., ["output"], or
291+ ["output", "running_mean", "running_var", "saved_mean"])
292+ outputs:
293+ - output: FLOAT [1, 3, 224, 224]
294+
295+ If `extra_value_infos` is provided (e.g., value_info for non-training outputs
296+ like "running_mean"/"running_var" and/or training-only outputs like
297+ "saved_mean"/"saved_inv_std"), they are attached to the graph's value_info.
298+ Some tests subsequently invoke utilities (e.g., remove_node_training_mode)
299+ that prune training-only outputs and their value_info entries, while keeping
300+ regular outputs such as "running_mean" and "running_var" intact.
301+ """
276302 initializers = [
277303 _make_bn_initializer ("scale" , [3 ], 1.0 ),
278304 _make_bn_initializer ("bias" , [3 ], 0.0 ),
@@ -318,7 +344,13 @@ def test_remove_node_extra_training_outputs():
318344 bn_node = make_node (
319345 "BatchNormalization" ,
320346 inputs = ["input" , "scale" , "bias" , "mean" , "var" ],
321- outputs = ["output" , "saved_mean" , "saved_inv_std" ], # Extra training outputs
347+ outputs = [
348+ "output" ,
349+ "running_mean" ,
350+ "running_var" ,
351+ "saved_mean" ,
352+ "saved_inv_std" ,
353+ ], # Extra training outputs
322354 name = "bn1" ,
323355 training_mode = 1 ,
324356 )
0 commit comments