Commit 082ed96

add a test
1 parent 4bae1d0 commit 082ed96

File tree

2 files changed: +60 -1 lines changed


onnxscript/rewriter/rules/common/_fuse_batchnorm.py

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu
         else:
             original_bias = np.zeros_like(input_mean)
         # Use inbound input 1 (should be weight) to derive a name for the bias
-        # to avoid name collision on initializer creation.
+        # to avoid name collision on initializer creation when there are multiple patterns
+        # sharing the same parent nodes.
         bias_name = inbound_node.inputs[1].name + "_bias"
         fused_bias = ir.tensor((original_bias - input_mean) * scale_factor + beta)
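For context on the unchanged fused_bias line: the rewrite folds the BatchNormalization statistics into the Conv bias. Below is a minimal numpy sketch of that arithmetic, assuming scale_factor = gamma / sqrt(input_var + eps), the standard folding term; its definition sits outside this hunk.

    import numpy as np

    # Minimal sketch (not the rewriter's code): fold BatchNorm statistics into
    # a per-channel Conv bias. scale_factor is an assumption; the diff above
    # only shows the bias computation.
    rng = np.random.default_rng(0)
    channels, eps = 4, 1e-5
    conv_out = rng.standard_normal(channels)           # per-channel Conv output, no bias
    gamma = rng.standard_normal(channels)
    beta = rng.standard_normal(channels)
    input_mean = rng.standard_normal(channels)
    input_var = np.abs(rng.standard_normal(channels))

    scale_factor = gamma / np.sqrt(input_var + eps)    # assumed, per standard BN folding
    original_bias = np.zeros_like(input_mean)          # matches the else-branch above
    fused_bias = (original_bias - input_mean) * scale_factor + beta

    bn_out = (conv_out - input_mean) * scale_factor + beta   # BatchNorm applied after Conv
    fused_out = conv_out * scale_factor + fused_bias         # fused Conv output
    np.testing.assert_allclose(bn_out, fused_out, rtol=1e-6)

Scaling the Conv weights by the same per-channel factor (not shown in this hunk) completes the fusion.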

onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py

Lines changed: 58 additions & 0 deletions
@@ -253,6 +253,64 @@ def test_fuse_batchnorm_graph_inputs(self):
         # No changes were applied as W is a graph input
         self.assertEqual(count, 0)
 
+    def test_fuse_batchnorm_does_not_collide_names_with_same_parent_node(self):
+        model_proto = onnx.parser.parse_model("""
+            < ir_version: 7, opset_import: ["" : 17] >
+            test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y1, float [N, ?, ?, ?] Y2)
+            {
+                X1 = MaxPool<kernel_shape=[3,3]>(X)
+                X2 = Conv(X1, W1)
+                Y1 = BatchNormalization(X2, gamma_64, beta_64, input_mean_64, input_var_64)
+                X3 = Conv(X1, W2)
+                Y2 = BatchNormalization(X3, gamma_256, beta_256, input_mean_256, input_var_256)
+            }
+        """)
+        initializers = [
+            onnx.numpy_helper.from_array(
+                np.random.randn(64, 32, 3, 3).astype(np.float32), name="W1"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(64).astype(np.float32), name="gamma_64"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(64).astype(np.float32), name="beta_64"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(64).astype(np.float32), name="input_mean_64"
+            ),
+            onnx.numpy_helper.from_array(
+                np.abs(np.random.randn(64)).astype(np.float32), name="input_var_64"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(256, 32, 3, 3).astype(np.float32), name="W2"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(256).astype(np.float32), name="gamma_256"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(256).astype(np.float32), name="beta_256"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(256).astype(np.float32), name="input_mean_256"
+            ),
+            onnx.numpy_helper.from_array(
+                np.abs(np.random.randn(256)).astype(np.float32), name="input_var_256"
+            ),
+        ]
+        model_proto.graph.initializer.extend(initializers)
+        onnx.checker.check_model(model_proto, True)
+        model = ir.serde.deserialize_model(model_proto)
+        count = _fuse_batchnorm.rules.apply_to_model(model)
+
+        # Applied twice, once for each BatchNorm
+        self.assertEqual(count, 2)
+        # it should have different bias names for the two fused Conv nodes
+        conv_nodes = [node for node in model.graph if node.op_type == "Conv"]
+        self.assertEqual(len(conv_nodes), 2)
+        bias_names_1 = conv_nodes[0].inputs[2].name
+        bias_names_2 = conv_nodes[1].inputs[2].name
+        self.assertNotEqual(bias_names_1, bias_names_2)
+
 
 if __name__ == "__main__":
     unittest.main()
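The scenario the test constructs is two Conv nodes fed by the same MaxPool output X1, so a bias name derived from the shared activation input would be created twice, while the weight input is unique per Conv. A hypothetical sketch of that distinction (plain dicts with invented names, not the rewriter's actual data structures):

    # Hypothetical illustration: two Conv nodes share the parent output "X1";
    # a name keyed on the shared activation collides, one keyed on the weight
    # does not.
    convs = [
        {"inputs": ("X1", "W1")},  # first Conv: shared activation, unique weight
        {"inputs": ("X1", "W2")},  # second Conv
    ]
    from_activation = {conv["inputs"][0] + "_bias" for conv in convs}
    from_weight = {conv["inputs"][1] + "_bias" for conv in convs}
    assert from_activation == {"X1_bias"}          # one name for two initializers: collision
    assert from_weight == {"W1_bias", "W2_bias"}   # distinct names, as the test asserts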
