@@ -72,7 +72,7 @@ def test_conv(conv):
     layer = conv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim)
 
     out_1, out_2 = apply_layer_rotation(layer)
-    assert torch.allclose(out_1, out_2, atol=1e-10)
+    torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
 
 
 @pytest.mark.parametrize("conv", CONV_LAYERS)
@@ -89,7 +89,7 @@ def test_gated_conv(conv):
     layer = Gated(wrapped, irreps_in=irreps_in, irreps_out=irreps_in)
 
     out_1, out_2 = apply_layer_rotation(layer)
-    assert torch.allclose(out_1, out_2, atol=1e-10)
+    torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
 
 
 @pytest.mark.parametrize("conv", CONV_LAYERS)
@@ -108,7 +108,7 @@ def test_conv_block(conv):
     )
 
     out_1, out_2 = apply_layer_rotation(layer)
-    assert torch.allclose(out_1, out_2, atol=1e-10)
+    torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
 
 
 @pytest.mark.parametrize("conv", CONV_LAYERS)
@@ -132,7 +132,7 @@ def test_attention(conv):
     )
 
     out_1, out_2 = apply_layer_rotation(layer)
-    assert torch.allclose(out_1, out_2, atol=1e-10)
+    torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
 
 
 @pytest.mark.parametrize("conv", [Conv, SeparableConv])
@@ -158,7 +158,7 @@ def test_multihead_attention(conv):
     )
 
     out_1, out_2 = apply_layer_rotation(layer)
-    assert torch.allclose(out_1, out_2, atol=1e-10)
+    torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
 
 
 def test_layer_norm():
@@ -174,7 +174,7 @@ def test_layer_norm():
     out_1 = layer(x @ D.T)
     out_2 = layer(x) @ D.T
 
-    assert torch.allclose(out_1, out_2, atol=1e-10)
+    torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
 
 
 def test_equivariant_mlp():
@@ -194,7 +194,7 @@ def test_equivariant_mlp():
    out_1 = layer(x @ D.T)
    out_2 = layer(x) @ D.T
 
-    assert torch.allclose(out_1, out_2, atol=1e-10)
+    torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
 
 
 def test_transformer():
@@ -214,4 +214,4 @@ def test_transformer():
     )
 
     out_1, out_2 = apply_layer_rotation(layer)
-    assert torch.allclose(out_1, out_2, atol=1e-10)
+    torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
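A minimal standalone sketch of the assertion style this diff standardizes on (the tensors below are placeholders, not the layer outputs from the tests above): torch.testing.assert_close raises an AssertionError with a diagnostic mismatch report instead of returning a bool like torch.allclose, and it requires atol and rtol to be passed together, which is why rtol=1e-10 accompanies every atol=1e-10 in the hunks above.

import torch

# Placeholder tensors standing in for the rotated-input vs. rotated-output results.
out_1 = torch.randn(4, 8, dtype=torch.float64)
out_2 = out_1 + 1e-12  # perturbation well inside the 1e-10 tolerance

# Raises AssertionError with a detailed report if the tensors differ beyond tolerance;
# note that atol and rtol must be specified together (or both omitted).
torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)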