Skip to content

Commit a665501

Browse files
committed
Fix test notation.
1 parent e3c56ba commit a665501

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

tests/test_basic.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Tuple
12
import functools
23

34
import pytest
@@ -23,7 +24,8 @@
2324
CONV_LAYERS = [Conv, SeparableConv, ExperimentalConv]
2425

2526

26-
def apply_layer_rot(layer):
27+
def apply_layer_rotation(layer: torch.nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
28+
"""Applies a rotation and returns the output of the layer with the rotation applied before and after."""
2729
N = 20
2830
edge_attr_dim = 10
2931
max_radius = 1.3
@@ -43,11 +45,11 @@ def apply_layer_rot(layer):
4345
cutoff=True,
4446
)
4547

46-
edge_sh = o3.spherical_harmonics(
48+
edge_sh = e3nn.o3.spherical_harmonics(
4749
layer.irreps_sh, edge_vec, True, normalization="component"
4850
)
4951

50-
rot = o3.rand_matrix()
52+
rot = e3nn.o3.rand_matrix()
5153

5254
D_node_attr = layer.irreps_in.D_from_matrix(rot)
5355
D_edge_sh = layer.irreps_sh.D_from_matrix(rot)
@@ -64,33 +66,33 @@ def apply_layer_rot(layer):
6466

6567
@pytest.mark.parametrize("conv", CONV_LAYERS)
6668
def test_conv(conv):
67-
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
69+
irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e")
6870
irreps_sh = irreps_in.spherical_harmonics(2)
6971
edge_attr_dim = 10
7072

7173
layer = conv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim)
7274

73-
out_1, out_2 = apply_layer_rot(layer)
75+
out_1, out_2 = apply_layer_rotation(layer)
7476
assert torch.allclose(out_1, out_2, atol=1e-10)
7577

7678

7779
@pytest.mark.parametrize("conv", CONV_LAYERS)
7880
def test_gated_conv(conv):
79-
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
81+
irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e")
8082
irreps_sh = irreps_in.spherical_harmonics(2)
8183
edge_attr_dim = 10
8284

8385
wrapped = functools.partial(conv, irreps_sh=irreps_sh, edge_attr_dim=edge_attr_dim)
8486

8587
layer = Gated(wrapped, irreps_in=irreps_in, irreps_out=irreps_in)
8688

87-
out_1, out_2 = apply_layer_rot(layer)
89+
out_1, out_2 = apply_layer_rotation(layer)
8890
assert torch.allclose(out_1, out_2, atol=1e-10)
8991

9092

9193
@pytest.mark.parametrize("conv", CONV_LAYERS)
9294
def test_conv_block(conv):
93-
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
95+
irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e")
9496
irreps_sh = irreps_in.spherical_harmonics(2)
9597
edge_attr_dim = 10
9698

@@ -102,13 +104,13 @@ def test_conv_block(conv):
102104
conv=conv,
103105
)
104106

105-
out_1, out_2 = apply_layer_rot(layer)
107+
out_1, out_2 = apply_layer_rotation(layer)
106108
assert torch.allclose(out_1, out_2, atol=1e-10)
107109

108110

109111
@pytest.mark.parametrize("conv", CONV_LAYERS)
110112
def test_attention(conv):
111-
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
113+
irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e")
112114
irreps_out = irreps_in
113115
irreps_sh = irreps_in.spherical_harmonics(2)
114116
irreps_key = irreps_in
@@ -125,13 +127,13 @@ def test_attention(conv):
125127
conv=conv,
126128
)
127129

128-
out_1, out_2 = apply_layer_rot(layer)
130+
out_1, out_2 = apply_layer_rotation(layer)
129131
assert torch.allclose(out_1, out_2, atol=1e-10)
130132

131133

132134
@pytest.mark.parametrize("conv", [Conv, SeparableConv])
133135
def test_multihead_attention(conv):
134-
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
136+
irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e")
135137
irreps_out = irreps_in
136138
irreps_sh = irreps_in.spherical_harmonics(2)
137139
irreps_key = irreps_in
@@ -150,15 +152,15 @@ def test_multihead_attention(conv):
150152
conv=conv,
151153
)
152154

153-
out_1, out_2 = apply_layer_rot(layer)
155+
out_1, out_2 = apply_layer_rotation(layer)
154156
assert torch.allclose(out_1, out_2, atol=1e-10)
155157

156158

157159
def test_layer_norm():
158-
irreps = o3.Irreps("10x0e + 10x1o + 10x2e")
160+
irreps = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e")
159161

160162
layer = LayerNorm(irreps)
161-
rot = o3.rand_matrix()
163+
rot = e3nn.o3.rand_matrix()
162164
D = irreps.D_from_matrix(rot)
163165

164166
x = irreps.randn(10, -1)
@@ -170,14 +172,14 @@ def test_layer_norm():
170172

171173

172174
def test_equivariant_mlp():
173-
irreps = o3.Irreps("10x0e + 10x1o + 10x2e")
174-
irreps_hidden = o3.Irreps([(4 * mul, ir) for mul, ir in irreps])
175+
irreps = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e")
176+
irreps_hidden = e3nn.o3.Irreps([(4 * mul, ir) for mul, ir in irreps])
175177

176178
layer = EquivariantMLP(
177179
irreps, irreps, [irreps_hidden, irreps_hidden], norm_layer=LayerNorm
178180
)
179181

180-
rot = o3.rand_matrix()
182+
rot = e3nn.o3.rand_matrix()
181183
D = irreps.D_from_matrix(rot)
182184

183185
x = irreps.randn(10, -1)
@@ -189,7 +191,7 @@ def test_equivariant_mlp():
189191

190192

191193
def test_transformer():
192-
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
194+
irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e")
193195
irreps_out = irreps_in
194196
irreps_sh = irreps_in.spherical_harmonics(2)
195197
edge_attr_dim = 10
@@ -203,5 +205,5 @@ def test_transformer():
203205
n_head=n_head,
204206
)
205207

206-
out_1, out_2 = apply_layer_rot(layer)
208+
out_1, out_2 = apply_layer_rotation(layer)
207209
assert torch.allclose(out_1, out_2, atol=1e-10)

0 commit comments

Comments
 (0)