
Commit 7ce78c0

Fix nxp unittests (#15384)

Let's not hard-code the number of nodes in the graph. After the torch nightly from 10/15, the graph of batch norm + conv looks like this:

```
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_batch_norm_batch_norm_weight : [num_users=1] = placeholder[target=p_batch_norm_batch_norm_weight]
    %p_batch_norm_batch_norm_bias : [num_users=1] = placeholder[target=p_batch_norm_batch_norm_bias]
    %b_batch_norm_batch_norm_running_mean : [num_users=1] = placeholder[target=b_batch_norm_batch_norm_running_mean]
    %b_batch_norm_batch_norm_running_var : [num_users=1] = placeholder[target=b_batch_norm_batch_norm_running_var]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %p_batch_norm_batch_norm_weight, %p_batch_norm_batch_norm_bias, %b_batch_norm_batch_norm_running_mean, %b_batch_norm_batch_norm_running_var, False, 0.1, 1e-05, True), kwargs = {})
    return (batch_norm,)
```
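For reference, a minimal sketch of how a graph dump like the one above can be produced. The module, channel sizes, and input shape here are illustrative assumptions rather than the test's actual setup, and the exact placeholder names and node count depend on the torch version and module nesting, which is exactly why the tests stop counting nodes:

```python
import torch


class ConvBatchNorm(torch.nn.Module):
    # Conv2d followed by BatchNorm2d: the pattern the fusion pass targets.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.batch_norm = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.batch_norm(self.conv(x))


# Export to an ATen-level FX graph and print it. Placeholders for parameters
# and buffers (weights, biases, running stats) appear ahead of the input %x,
# so their number and order can shift between torch releases.
program = torch.export.export(ConvBatchNorm().eval(), (torch.randn(1, 3, 32, 32),))
print(program.graph_module.graph)
```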
Parent: 51b83ff

backends/nxp/tests/test_batch_norm_fusion.py

Lines changed: 10 additions & 8 deletions
```diff
@@ -105,10 +105,11 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
     og_nodes = list(program.graph.nodes)
     transformed_nodes = list(graph_module_out.graph.nodes)
 
-    assert len(og_nodes) == (11 if bias else 10)
-    assert og_nodes[9 if bias else 8].target.__name__ == "batch_norm.default"
+    assert any(
+        node.op == "call_function" and node.target.__name__ == "batch_norm.default"
+        for node in og_nodes
+    )
 
-    assert len(transformed_nodes) == 5
     assert not any(
         node.op == "call_function" and "batch_norm" in node.target.__name__
         for node in transformed_nodes
@@ -118,7 +119,7 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
     input_data = torch.randn(input_shape, dtype=torch.float32)
     out1 = og_module(input_data).detach().numpy()
     out2 = graph_module_out(input_data).detach().numpy()
-    assert np.allclose(out1, out2, atol=3.0e-7)
+    torch.testing.assert_close(out1, out2)
 
 
 @pytest.mark.parametrize(
@@ -139,10 +140,11 @@ def test_batch_norm_linear_fusing(bias: bool):
     og_nodes = list(og_module.graph.nodes)
     transformed_nodes = list(graph_module_out.graph.nodes)
 
-    assert len(og_nodes) == (11 if bias else 10)
-    assert og_nodes[8 if bias else 7].target.__name__ == "linear.default"
+    assert any(
+        node.op == "call_function" and node.target.__name__ == "linear.default"
+        for node in og_nodes
+    )
 
-    assert len(transformed_nodes) == 5
     assert not any(
         node.op == "call_function" and "batch_norm" in node.target.__name__
         for node in transformed_nodes
@@ -152,7 +154,7 @@ def test_batch_norm_linear_fusing(bias: bool):
     input_data = torch.randn(input_shape, dtype=torch.float32)
     out1 = og_module(input_data).detach().numpy()
     out2 = graph_module_out(input_data).detach().numpy()
-    assert np.allclose(out1, out2, atol=1.2e-7)
+    torch.testing.assert_close(out1, out2)
 
 
 @pytest.mark.parametrize(
```
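As an aside on the tolerance change: `np.allclose` returns a bare bool and relies on a hand-tuned `atol` that can prove too tight on some builds, while `torch.testing.assert_close` picks dtype-aware default tolerances, accepts numpy arrays directly, and raises with a diagnostic message on mismatch. A minimal sketch of the difference, with values made up for illustration:

```python
import numpy as np
import torch

out1 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
out2 = out1 + 1e-7  # tiny float32 discrepancy, e.g. from operator fusion

# Old style: a pinned atol and no diagnostics on failure.
assert np.allclose(out1, out2, atol=3.0e-7)

# New style: dtype-aware defaults (rtol=1.3e-6, atol=1e-5 for float32) and a
# detailed error message showing the offending elements when it fails.
torch.testing.assert_close(out1, out2)
```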
