1+ from typing import Tuple
12import functools
23
34import pytest
2324CONV_LAYERS = [Conv , SeparableConv , ExperimentalConv ]
2425
2526
26- def apply_layer_rot (layer ):
27+ def apply_layer_rotation (layer : torch .nn .Module ) -> Tuple [torch .Tensor , torch .Tensor ]:
28+ """Applies a rotation and returns the output of the layer with the rotation applied before and after."""
2729 N = 20
2830 edge_attr_dim = 10
2931 max_radius = 1.3
@@ -43,11 +45,11 @@ def apply_layer_rot(layer):
4345 cutoff = True ,
4446 )
4547
46- edge_sh = o3 .spherical_harmonics (
48+ edge_sh = e3nn . o3 .spherical_harmonics (
4749 layer .irreps_sh , edge_vec , True , normalization = "component"
4850 )
4951
50- rot = o3 .rand_matrix ()
52+ rot = e3nn . o3 .rand_matrix ()
5153
5254 D_node_attr = layer .irreps_in .D_from_matrix (rot )
5355 D_edge_sh = layer .irreps_sh .D_from_matrix (rot )
@@ -64,33 +66,33 @@ def apply_layer_rot(layer):
6466
6567@pytest .mark .parametrize ("conv" , CONV_LAYERS )
6668def test_conv (conv ):
67- irreps_in = o3 .Irreps ("10x0e + 10x1o + 10x2e" )
69+ irreps_in = e3nn . o3 .Irreps ("10x0e + 10x1o + 10x2e" )
6870 irreps_sh = irreps_in .spherical_harmonics (2 )
6971 edge_attr_dim = 10
7072
7173 layer = conv (irreps_in , irreps_in , irreps_sh , edge_attr_dim = edge_attr_dim )
7274
73- out_1 , out_2 = apply_layer_rot (layer )
75+ out_1 , out_2 = apply_layer_rotation (layer )
7476 assert torch .allclose (out_1 , out_2 , atol = 1e-10 )
7577
7678
7779@pytest .mark .parametrize ("conv" , CONV_LAYERS )
7880def test_gated_conv (conv ):
79- irreps_in = o3 .Irreps ("10x0e + 10x1o + 10x2e" )
81+ irreps_in = e3nn . o3 .Irreps ("10x0e + 10x1o + 10x2e" )
8082 irreps_sh = irreps_in .spherical_harmonics (2 )
8183 edge_attr_dim = 10
8284
8385 wrapped = functools .partial (conv , irreps_sh = irreps_sh , edge_attr_dim = edge_attr_dim )
8486
8587 layer = Gated (wrapped , irreps_in = irreps_in , irreps_out = irreps_in )
8688
87- out_1 , out_2 = apply_layer_rot (layer )
89+ out_1 , out_2 = apply_layer_rotation (layer )
8890 assert torch .allclose (out_1 , out_2 , atol = 1e-10 )
8991
9092
9193@pytest .mark .parametrize ("conv" , CONV_LAYERS )
9294def test_conv_block (conv ):
93- irreps_in = o3 .Irreps ("10x0e + 10x1o + 10x2e" )
95+ irreps_in = e3nn . o3 .Irreps ("10x0e + 10x1o + 10x2e" )
9496 irreps_sh = irreps_in .spherical_harmonics (2 )
9597 edge_attr_dim = 10
9698
@@ -102,13 +104,13 @@ def test_conv_block(conv):
102104 conv = conv ,
103105 )
104106
105- out_1 , out_2 = apply_layer_rot (layer )
107+ out_1 , out_2 = apply_layer_rotation (layer )
106108 assert torch .allclose (out_1 , out_2 , atol = 1e-10 )
107109
108110
109111@pytest .mark .parametrize ("conv" , CONV_LAYERS )
110112def test_attention (conv ):
111- irreps_in = o3 .Irreps ("10x0e + 10x1o + 10x2e" )
113+ irreps_in = e3nn . o3 .Irreps ("10x0e + 10x1o + 10x2e" )
112114 irreps_out = irreps_in
113115 irreps_sh = irreps_in .spherical_harmonics (2 )
114116 irreps_key = irreps_in
@@ -125,13 +127,13 @@ def test_attention(conv):
125127 conv = conv ,
126128 )
127129
128- out_1 , out_2 = apply_layer_rot (layer )
130+ out_1 , out_2 = apply_layer_rotation (layer )
129131 assert torch .allclose (out_1 , out_2 , atol = 1e-10 )
130132
131133
132134@pytest .mark .parametrize ("conv" , [Conv , SeparableConv ])
133135def test_multihead_attention (conv ):
134- irreps_in = o3 .Irreps ("10x0e + 10x1o + 10x2e" )
136+ irreps_in = e3nn . o3 .Irreps ("10x0e + 10x1o + 10x2e" )
135137 irreps_out = irreps_in
136138 irreps_sh = irreps_in .spherical_harmonics (2 )
137139 irreps_key = irreps_in
@@ -150,15 +152,15 @@ def test_multihead_attention(conv):
150152 conv = conv ,
151153 )
152154
153- out_1 , out_2 = apply_layer_rot (layer )
155+ out_1 , out_2 = apply_layer_rotation (layer )
154156 assert torch .allclose (out_1 , out_2 , atol = 1e-10 )
155157
156158
157159def test_layer_norm ():
158- irreps = o3 .Irreps ("10x0e + 10x1o + 10x2e" )
160+ irreps = e3nn . o3 .Irreps ("10x0e + 10x1o + 10x2e" )
159161
160162 layer = LayerNorm (irreps )
161- rot = o3 .rand_matrix ()
163+ rot = e3nn . o3 .rand_matrix ()
162164 D = irreps .D_from_matrix (rot )
163165
164166 x = irreps .randn (10 , - 1 )
@@ -170,14 +172,14 @@ def test_layer_norm():
170172
171173
172174def test_equivariant_mlp ():
173- irreps = o3 .Irreps ("10x0e + 10x1o + 10x2e" )
174- irreps_hidden = o3 .Irreps ([(4 * mul , ir ) for mul , ir in irreps ])
175+ irreps = e3nn . o3 .Irreps ("10x0e + 10x1o + 10x2e" )
176+ irreps_hidden = e3nn . o3 .Irreps ([(4 * mul , ir ) for mul , ir in irreps ])
175177
176178 layer = EquivariantMLP (
177179 irreps , irreps , [irreps_hidden , irreps_hidden ], norm_layer = LayerNorm
178180 )
179181
180- rot = o3 .rand_matrix ()
182+ rot = e3nn . o3 .rand_matrix ()
181183 D = irreps .D_from_matrix (rot )
182184
183185 x = irreps .randn (10 , - 1 )
@@ -189,7 +191,7 @@ def test_equivariant_mlp():
189191
190192
191193def test_transformer ():
192- irreps_in = o3 .Irreps ("10x0e + 10x1o + 10x2e" )
194+ irreps_in = e3nn . o3 .Irreps ("10x0e + 10x1o + 10x2e" )
193195 irreps_out = irreps_in
194196 irreps_sh = irreps_in .spherical_harmonics (2 )
195197 edge_attr_dim = 10
@@ -203,5 +205,5 @@ def test_transformer():
203205 n_head = n_head ,
204206 )
205207
206- out_1 , out_2 = apply_layer_rot (layer )
208+ out_1 , out_2 = apply_layer_rotation (layer )
207209 assert torch .allclose (out_1 , out_2 , atol = 1e-10 )
0 commit comments