@@ -253,6 +253,64 @@ def test_fuse_batchnorm_graph_inputs(self):
253253 # No changes were applied as W is a graph input
254254 self .assertEqual (count , 0 )
255255
256+ def test_fuse_batchnorm_does_not_collide_names_with_same_parent_node (self ):
257+ model_proto = onnx .parser .parse_model ("""
258+ < ir_version: 7, opset_import: ["" : 17] >
259+ test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y1, float [N, ?, ?, ?] Y2)
260+ {
261+ X1 = MaxPool<kernel_shape=[3,3]>(X)
262+ X2 = Conv(X1, W1)
263+ Y1 = BatchNormalization(X2, gamma_64, beta_64, input_mean_64, input_var_64)
264+ X3 = Conv(X1, W2)
265+ Y2 = BatchNormalization(X3, gamma_256, beta_256, input_mean_256, input_var_256)
266+ }
267+ """ )
268+ initializers = [
269+ onnx .numpy_helper .from_array (
270+ np .random .randn (64 , 32 , 3 , 3 ).astype (np .float32 ), name = "W1"
271+ ),
272+ onnx .numpy_helper .from_array (
273+ np .random .randn (64 ).astype (np .float32 ), name = "gamma_64"
274+ ),
275+ onnx .numpy_helper .from_array (
276+ np .random .randn (64 ).astype (np .float32 ), name = "beta_64"
277+ ),
278+ onnx .numpy_helper .from_array (
279+ np .random .randn (64 ).astype (np .float32 ), name = "input_mean_64"
280+ ),
281+ onnx .numpy_helper .from_array (
282+ np .abs (np .random .randn (64 )).astype (np .float32 ), name = "input_var_64"
283+ ),
284+ onnx .numpy_helper .from_array (
285+ np .random .randn (256 , 32 , 3 , 3 ).astype (np .float32 ), name = "W2"
286+ ),
287+ onnx .numpy_helper .from_array (
288+ np .random .randn (256 ).astype (np .float32 ), name = "gamma_256"
289+ ),
290+ onnx .numpy_helper .from_array (
291+ np .random .randn (256 ).astype (np .float32 ), name = "beta_256"
292+ ),
293+ onnx .numpy_helper .from_array (
294+ np .random .randn (256 ).astype (np .float32 ), name = "input_mean_256"
295+ ),
296+ onnx .numpy_helper .from_array (
297+ np .abs (np .random .randn (256 )).astype (np .float32 ), name = "input_var_256"
298+ ),
299+ ]
300+ model_proto .graph .initializer .extend (initializers )
301+ onnx .checker .check_model (model_proto , True )
302+ model = ir .serde .deserialize_model (model_proto )
303+ count = _fuse_batchnorm .rules .apply_to_model (model )
304+
305+ # Applied twice, once for each BatchNorm
306+ self .assertEqual (count , 2 )
307+ # it should have different bias names for the two fused Conv nodes
308+ conv_nodes = [node for node in model .graph if node .op_type == "Conv" ]
309+ self .assertEqual (len (conv_nodes ), 2 )
310+ bias_names_1 = conv_nodes [0 ].inputs [2 ].name
311+ bias_names_2 = conv_nodes [1 ].inputs [2 ].name
312+ self .assertNotEqual (bias_names_1 , bias_names_2 )
313+

if __name__ == "__main__":
    # Entry point: discover and run all tests in this module.
    unittest.main()
0 commit comments