@@ -306,11 +306,20 @@ def _make_batchnorm_model(bn_node, extra_value_infos=None):
306
306
_make_bn_initializer ("var" , [3 ], 1.0 ),
307
307
]
308
308
309
+ graph_outputs = []
310
+ for output_name , shape in [
311
+ ("output" , [1 , 3 , 224 , 224 ]),
312
+ ("running_mean" , [3 ]),
313
+ ("running_var" , [3 ]),
314
+ ]:
315
+ if output_name in bn_node .output :
316
+ graph_outputs .append (make_tensor_value_info (output_name , onnx .TensorProto .FLOAT , shape ))
317
+
309
318
graph_def = make_graph (
310
319
[bn_node ],
311
320
"test_graph" ,
312
321
[make_tensor_value_info ("input" , onnx .TensorProto .FLOAT , [1 , 3 , 224 , 224 ])],
313
- [ make_tensor_value_info ( "output" , onnx . TensorProto . FLOAT , [ 1 , 3 , 224 , 224 ])] ,
322
+ graph_outputs ,
314
323
initializer = initializers ,
315
324
value_info = extra_value_infos or [],
316
325
)
@@ -350,11 +359,12 @@ def test_remove_node_extra_training_outputs():
350
359
"running_var" ,
351
360
"saved_mean" ,
352
361
"saved_inv_std" ,
353
- ], # Extra training outputs
362
+ ],
354
363
name = "bn1" ,
355
364
training_mode = 1 ,
356
365
)
357
366
367
+ # Extra training outputs are attached to the graph's value_info
358
368
value_infos = [
359
369
make_tensor_value_info ("saved_mean" , onnx .TensorProto .FLOAT , [3 ]),
360
370
make_tensor_value_info ("saved_inv_std" , onnx .TensorProto .FLOAT , [3 ]),
@@ -363,10 +373,13 @@ def test_remove_node_extra_training_outputs():
363
373
model = _make_batchnorm_model (bn_node , extra_value_infos = value_infos )
364
374
result_model = remove_node_training_mode (model , "BatchNormalization" )
365
375
366
- # Verify only first output remains
376
+ # Verify only the non-training outputs remain
367
377
bn_node_result = result_model .graph .node [0 ]
368
- assert len (bn_node_result .output ) == 1
378
+ print (bn_node_result .output )
379
+ assert len (bn_node_result .output ) == 3
369
380
assert bn_node_result .output [0 ] == "output"
381
+ assert bn_node_result .output [2 ] == "running_var"
382
+ assert bn_node_result .output [1 ] == "running_mean"
370
383
371
384
# Verify value_info entries for removed outputs are cleaned up
372
385
value_info_names = [vi .name for vi in result_model .graph .value_info ]
0 commit comments