@@ -105,10 +105,11 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
105105 og_nodes = list (program .graph .nodes )
106106 transformed_nodes = list (graph_module_out .graph .nodes )
107107
108- assert len (og_nodes ) == (11 if bias else 10 )
109- assert og_nodes [9 if bias else 8 ].target .__name__ == "batch_norm.default"
108+ assert any (
109+ node .op == "call_function" and node .target .__name__ == "batch_norm.default"
110+ for node in og_nodes
111+ )
110112
111- assert len (transformed_nodes ) == 5
112113 assert not any (
113114 node .op == "call_function" and "batch_norm" in node .target .__name__
114115 for node in transformed_nodes
@@ -118,7 +119,7 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
118119 input_data = torch .randn (input_shape , dtype = torch .float32 )
119120 out1 = og_module (input_data ).detach ().numpy ()
120121 out2 = graph_module_out (input_data ).detach ().numpy ()
121- assert np . allclose (out1 , out2 , atol = 3.0e-7 )
122+ torch . testing . assert_close (out1 , out2 )
122123
123124
124125@pytest .mark .parametrize (
@@ -139,10 +140,11 @@ def test_batch_norm_linear_fusing(bias: bool):
139140 og_nodes = list (og_module .graph .nodes )
140141 transformed_nodes = list (graph_module_out .graph .nodes )
141142
142- assert len (og_nodes ) == (11 if bias else 10 )
143- assert og_nodes [8 if bias else 7 ].target .__name__ == "linear.default"
143+ assert any (
144+ node .op == "call_function" and node .target .__name__ == "linear.default"
145+ for node in og_nodes
146+ )
144147
145- assert len (transformed_nodes ) == 5
146148 assert not any (
147149 node .op == "call_function" and "batch_norm" in node .target .__name__
148150 for node in transformed_nodes
@@ -152,7 +154,7 @@ def test_batch_norm_linear_fusing(bias: bool):
152154 input_data = torch .randn (input_shape , dtype = torch .float32 )
153155 out1 = og_module (input_data ).detach ().numpy ()
154156 out2 = graph_module_out (input_data ).detach ().numpy ()
155- assert np . allclose (out1 , out2 , atol = 1.2e-7 )
157+ torch . testing . assert_close (out1 , out2 )
156158
157159
158160@pytest .mark .parametrize (
0 commit comments