@@ -272,7 +272,33 @@ def _make_bn_initializer(name: str, shape, value=1.0):
 
 
 def _make_batchnorm_model(bn_node, extra_value_infos=None):
-    """Helper to create an ONNX model with a BatchNormalization node."""
+    """Helper to create an ONNX model with a BatchNormalization node.
+
+    The created model has the following schematic structure:
+
+    graph name: "test_graph"
+    inputs:
+      - input: FLOAT [1, 3, 224, 224]
+    initializers:
+      - scale: FLOAT [3]
+      - bias: FLOAT [3]
+      - mean: FLOAT [3]
+      - var: FLOAT [3]
+    nodes:
+      - BatchNormalization (name comes from `bn_node`), with:
+          inputs = ["input", "scale", "bias", "mean", "var"]
+          outputs = as provided by `bn_node` (e.g., ["output"], or
+          ["output", "running_mean", "running_var", "saved_mean"])
+    outputs:
+      - output: FLOAT [1, 3, 224, 224]
+
+    If `extra_value_infos` is provided (e.g., value_info for non-training outputs
+    like "running_mean"/"running_var" and/or training-only outputs like
+    "saved_mean"/"saved_inv_std"), they are attached to the graph's value_info.
+    Some tests subsequently invoke utilities (e.g., remove_node_training_mode)
+    that prune training-only outputs and their value_info entries, while keeping
+    regular outputs such as "running_mean" and "running_var" intact.
+    """
     initializers = [
         _make_bn_initializer("scale", [3], 1.0),
         _make_bn_initializer("bias", [3], 0.0),
@@ -318,7 +344,13 @@ def test_remove_node_extra_training_outputs():
     bn_node = make_node(
         "BatchNormalization",
         inputs=["input", "scale", "bias", "mean", "var"],
-        outputs=["output", "saved_mean", "saved_inv_std"],  # Extra training outputs
+        outputs=[
+            "output",
+            "running_mean",
+            "running_var",
+            "saved_mean",
+            "saved_inv_std",
+        ],  # Running stats plus extra training-only outputs
         name="bn1",
         training_mode=1,
     )
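And a rough sketch, assuming only the name and behavior the docstring attributes to `remove_node_training_mode`, of what such a utility could do for BatchNormalization: drop the `training_mode` attribute, prune the training-only outputs beyond Y/running_mean/running_var (so "saved_mean" and "saved_inv_std" in the test above), and remove their value_info entries. The actual utility's signature and logic may differ:

```python
# Hedged sketch of the pruning behavior described in the docstring; not the
# library's actual implementation.
import onnx


def remove_node_training_mode(model: onnx.ModelProto) -> onnx.ModelProto:
    removed = set()
    for node in model.graph.node:
        if node.op_type != "BatchNormalization":
            continue
        # Drop the training_mode attribute, if present.
        for i, attr in enumerate(node.attribute):
            if attr.name == "training_mode":
                del node.attribute[i]
                break
        # Outputs beyond Y / running_mean / running_var are training-only
        # (e.g., "saved_mean" / "saved_inv_std" in the test above).
        removed.update(node.output[3:])
        del node.output[3:]
    # Prune value_info entries that described the removed outputs, keeping
    # regular outputs like "running_mean" / "running_var" intact.
    for i in reversed(range(len(model.graph.value_info))):
        if model.graph.value_info[i].name in removed:
            del model.graph.value_info[i]
    return model
```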