Skip to content

Commit 768c489

Browse files
committed
NXP backend: Fix batch norm tests.
1 parent aa8a8b9 commit 768c489

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

backends/nxp/tests/test_batch_norm_fusion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
105105
og_nodes = list(program.graph.nodes)
106106
transformed_nodes = list(graph_module_out.graph.nodes)
107107

108-
assert len(og_nodes) == (11 if bias else 10)
109-
assert og_nodes[9 if bias else 8].target.__name__ == "batch_norm.default"
108+
assert len(og_nodes) == (10 if bias else 9)
109+
assert og_nodes[8 if bias else 7].target.__name__ == "batch_norm.default"
110110

111111
assert len(transformed_nodes) == 5
112112
assert not any(
@@ -139,8 +139,8 @@ def test_batch_norm_linear_fusing(bias: bool):
139139
og_nodes = list(og_module.graph.nodes)
140140
transformed_nodes = list(graph_module_out.graph.nodes)
141141

142-
assert len(og_nodes) == (11 if bias else 10)
143-
assert og_nodes[8 if bias else 7].target.__name__ == "linear.default"
142+
assert len(og_nodes) == (10 if bias else 9)
143+
assert og_nodes[7 if bias else 6].target.__name__ == "linear.default"
144144

145145
assert len(transformed_nodes) == 5
146146
assert not any(

0 commit comments

Comments
 (0)