Skip to content

Commit 5b5b8b7

Browse files
committed
assert torch.allclose -> torch.testing.assert_close
1 parent 4bd64ad commit 5b5b8b7

File tree

7 files changed

+36
-29
lines changed

7 files changed

+36
-29
lines changed

tests/test_attention.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_compile(model_name, causal, request, irreps_in, irreps_sh, edge_attr_di
8787
out = model(node_attr, edge_index, edge_attr, edge_sh, mask=mask)
8888
compiled_out = compiled_model(node_attr, edge_index, edge_attr, edge_sh, mask=mask)
8989

90-
assert torch.allclose(out, compiled_out)
90+
torch.testing.assert_close(out, compiled_out)
9191

9292

9393
@pytest.mark.parametrize("model_name", ["singlehead_attention", "multihead_attention"])
@@ -124,20 +124,22 @@ def test_causal_vs_non_causal_attention(
124124
causal_mask = edge_index[0] <= edge_index[1]
125125
non_causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=None)
126126
causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=causal_mask)
127-
assert torch.allclose(non_causal_out, causal_out)
127+
torch.testing.assert_close(non_causal_out, causal_out)
128128

129129
# Check that the outputs are the same for the nodes that do not have any causal edges.
130130
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
131131
causal_mask = edge_index[0] <= edge_index[1]
132132
non_causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=None)
133133
causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=causal_mask)
134-
assert not torch.allclose(non_causal_out[:1], causal_out[:1])
135-
assert torch.allclose(non_causal_out[1:], causal_out[1:])
134+
with pytest.raises(AssertionError):
135+
torch.testing.assert_close(non_causal_out[:1], causal_out[:1])
136+
torch.testing.assert_close(non_causal_out[1:], causal_out[1:])
136137

137138
# Check that the outputs are the same for the nodes that do not have any causal edges.
138139
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 4]])
139140
causal_mask = edge_index[0] <= edge_index[1]
140141
non_causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=None)
141142
causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=causal_mask)
142-
assert not torch.allclose(non_causal_out[:2], causal_out[:2])
143-
assert torch.allclose(non_causal_out[2:], causal_out[2:])
143+
with pytest.raises(AssertionError):
144+
torch.testing.assert_close(non_causal_out[:2], causal_out[:2])
145+
torch.testing.assert_close(non_causal_out[2:], causal_out[2:])

tests/test_equivariance.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_conv(conv):
7272
layer = conv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim)
7373

7474
out_1, out_2 = apply_layer_rotation(layer)
75-
assert torch.allclose(out_1, out_2, atol=1e-10)
75+
torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
7676

7777

7878
@pytest.mark.parametrize("conv", CONV_LAYERS)
@@ -89,7 +89,7 @@ def test_gated_conv(conv):
8989
layer = Gated(wrapped, irreps_in=irreps_in, irreps_out=irreps_in)
9090

9191
out_1, out_2 = apply_layer_rotation(layer)
92-
assert torch.allclose(out_1, out_2, atol=1e-10)
92+
torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
9393

9494

9595
@pytest.mark.parametrize("conv", CONV_LAYERS)
@@ -108,7 +108,7 @@ def test_conv_block(conv):
108108
)
109109

110110
out_1, out_2 = apply_layer_rotation(layer)
111-
assert torch.allclose(out_1, out_2, atol=1e-10)
111+
torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
112112

113113

114114
@pytest.mark.parametrize("conv", CONV_LAYERS)
@@ -132,7 +132,7 @@ def test_attention(conv):
132132
)
133133

134134
out_1, out_2 = apply_layer_rotation(layer)
135-
assert torch.allclose(out_1, out_2, atol=1e-10)
135+
torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
136136

137137

138138
@pytest.mark.parametrize("conv", [Conv, SeparableConv])
@@ -158,7 +158,7 @@ def test_multihead_attention(conv):
158158
)
159159

160160
out_1, out_2 = apply_layer_rotation(layer)
161-
assert torch.allclose(out_1, out_2, atol=1e-10)
161+
torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
162162

163163

164164
def test_layer_norm():
@@ -174,7 +174,7 @@ def test_layer_norm():
174174
out_1 = layer(x @ D.T)
175175
out_2 = layer(x) @ D.T
176176

177-
assert torch.allclose(out_1, out_2, atol=1e-10)
177+
torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
178178

179179

180180
def test_equivariant_mlp():
@@ -194,7 +194,7 @@ def test_equivariant_mlp():
194194
out_1 = layer(x @ D.T)
195195
out_2 = layer(x) @ D.T
196196

197-
assert torch.allclose(out_1, out_2, atol=1e-10)
197+
torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)
198198

199199

200200
def test_transformer():
@@ -214,4 +214,4 @@ def test_transformer():
214214
)
215215

216216
out_1, out_2 = apply_layer_rotation(layer)
217-
assert torch.allclose(out_1, out_2, atol=1e-10)
217+
torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10)

tests/test_extract_irreps.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ def test_extract_irreps_simple():
3232

3333
layer = ExtractIrreps(irreps_in, "0e")
3434
output = layer(input)
35-
assert torch.allclose(output, torch.as_tensor([1.0]))
35+
torch.testing.assert_close(output, torch.as_tensor([1.0]))
3636

3737
layer = ExtractIrreps(irreps_in, "1o")
3838
output = layer(input)
39-
assert torch.allclose(output, torch.as_tensor([2.0, 3.0, 4.0]))
39+
torch.testing.assert_close(output, torch.as_tensor([2.0, 3.0, 4.0]))
4040

4141
layer = ExtractIrreps(irreps_in, "2e")
4242
output = layer(input)
43-
assert torch.allclose(output, torch.as_tensor([5.0, 6.0, 7.0, 8.0, 9.0]))
43+
torch.testing.assert_close(output, torch.as_tensor([5.0, 6.0, 7.0, 8.0, 9.0]))
4444

4545

4646
def test_extract_irreps_multiplicity():
@@ -50,8 +50,8 @@ def test_extract_irreps_multiplicity():
5050

5151
layer = ExtractIrreps(irreps_in, "0e")
5252
output = layer(input)
53-
assert torch.allclose(output, torch.as_tensor([1.0, 5.0, 6.0]))
53+
torch.testing.assert_close(output, torch.as_tensor([1.0, 5.0, 6.0]))
5454

5555
layer = ExtractIrreps(irreps_in, "1o")
5656
output = layer(input)
57-
assert torch.allclose(output, torch.as_tensor([2.0, 3.0, 4.0, 7.0, 8.0, 9.0]))
57+
torch.testing.assert_close(output, torch.as_tensor([2.0, 3.0, 4.0, 7.0, 8.0, 9.0]))

tests/test_fused_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,4 @@ def radial_nn(edge_attr_dim: int, num_elements: int) -> nn.Module:
8282
out = layer(node_attr, edge_index, edge_attr, edge_sh)
8383
out_fused = fused_layer(node_attr, edge_index, edge_attr, edge_sh)
8484

85-
assert torch.allclose(out, out_fused, rtol=1e-3)
85+
torch.testing.assert_close(out, out_fused, rtol=1e-3, atol=1e-5)

tests/test_layer_norm.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_layer_norm_compiled(irreps_in: str, seed: int, batch_size: int = 8):
4141
output = layer(input)
4242
output_compiled = layer_compiled(input)
4343

44-
assert torch.allclose(output, output_compiled)
44+
torch.testing.assert_close(output, output_compiled)
4545

4646

4747
@pytest.mark.parametrize(
@@ -57,8 +57,13 @@ def test_layer_norm(irreps_in: str):
5757
output = layer(input)
5858

5959
for mul, ir, field in unpack_irreps(output, irreps_in):
60-
sq_norms = field.norm(dim=-1, keepdim=True).pow(2).sum(dim=-1).mean(dim=-1)
60+
sq_norms = (
61+
field.norm(dim=-1, keepdim=True)
62+
.pow(2)
63+
.sum(dim=-1)
64+
.mean(dim=-1, keepdim=True)
65+
)
6166
if ir.l == 0 and ir.p == 1 and mul == 1:
62-
assert torch.allclose(sq_norms, torch.as_tensor([0.0]))
67+
torch.testing.assert_close(sq_norms, torch.as_tensor([0.0]))
6368
else:
64-
assert torch.allclose(sq_norms, torch.as_tensor([1.0]))
69+
torch.testing.assert_close(sq_norms, torch.as_tensor([1.0]))

tests/test_pack_unpack.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_inverse(irreps_in: str, factor: int, batch_size: int = 5):
100100
output = layer(input)
101101
recovered = inv_layer(output)
102102

103-
assert torch.allclose(input, recovered)
103+
torch.testing.assert_close(input, recovered)
104104

105105

106106
@pytest.mark.parametrize(
@@ -116,7 +116,7 @@ def test_axis_to_mul_compiled(irreps_in: str, factor: int, batch_size: int = 5):
116116
layer = AxisToMul(irreps_in, factor)
117117
layer_compiled = torch.compile(layer, fullgraph=True)
118118

119-
assert torch.allclose(layer(input), layer_compiled(input))
119+
torch.testing.assert_close(layer(input), layer_compiled(input))
120120

121121

122122
@pytest.mark.parametrize(
@@ -132,4 +132,4 @@ def test_mul_to_axis_compiled(irreps_in: str, factor: int, batch_size: int = 5):
132132
layer = MulToAxis(irreps_in, factor)
133133
layer_compiled = torch.compile(layer, fullgraph=True)
134134

135-
assert torch.allclose(layer(input), layer_compiled(input))
135+
torch.testing.assert_close(layer(input), layer_compiled(input))

tests/test_scaling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_scale_irreps_by_one(irreps_in: str):
3030
weight = torch.ones(irreps_in.num_irreps)
3131
output = layer(input, weight)
3232

33-
assert torch.allclose(input, output)
33+
torch.testing.assert_close(input, output)
3434

3535

3636
@pytest.mark.parametrize("irreps_in", ["0e + 1o", "0e + 1o + 2e", "3x1o + 2x2o"])
@@ -46,4 +46,4 @@ def test_scale_irreps_random(irreps_in: str):
4646

4747
norm = e3nn.o3.Norm(irreps_in)
4848
factor = norm(output) / norm(input)
49-
assert torch.allclose(factor, torch.abs(weight))
49+
torch.testing.assert_close(factor, torch.abs(weight))

0 commit comments

Comments
 (0)