Skip to content

Commit a9bc5fe

Browse files
committed
Fix layernorm.
1 parent b1c057a commit a9bc5fe

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/e3tools/nn/_layer_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
116116

117117
if is_scalar:
118118
# For scalar irreps (l=0, p=1), use standard layer norm
119-
field_norm = F.layer_norm(field_view, [dim], None, None, self.eps)
119+
field_norm = F.layer_norm(field_view, [mul, dim], None, None, self.eps)
120120
# Flatten back for concatenation
121121
field_out = field_norm.reshape(*batch_dims, size)
122122
else:

tests/test_layer_norm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def test_equivariance(irreps_in: str):
2424
[
2525
"0e + 1o",
2626
"32x0e + 1o + 2e",
27+
"0e + 4x1o + 5e",
2728
"3x1o + 2x2o",
2829
"8x1o + 2x2o + 1x3o",
2930
"3x1o + 2x2o + 1x3o + 1x4o",
@@ -43,7 +44,7 @@ def test_layer_norm_compiled(irreps_in: str, seed: int, batch_size: int = 8):
4344
assert torch.allclose(output, output_compiled)
4445

4546

46-
@pytest.mark.parametrize("irreps_in", ["0e + 1o", "32x0e + 1o + 2e", "3x1o + 2x2o"])
47+
@pytest.mark.parametrize("irreps_in", ["0e + 1o", "32x0e + 1o + 2e", "0e + 4x1o + 5e", "3x1o + 2x2o"])
4748
def test_layer_norm(irreps_in: str):
4849
irreps_in = e3nn.o3.Irreps(irreps_in)
4950
layer = LayerNorm(irreps_in)
@@ -53,9 +54,9 @@ def test_layer_norm(irreps_in: str):
5354
input = irreps_in.randn(-1)
5455
output = layer(input)
5556

56-
for _, ir, field in unpack_irreps(output, irreps_in):
57+
for mul, ir, field in unpack_irreps(output, irreps_in):
5758
sq_norms = field.norm(dim=-1, keepdim=True).pow(2).sum(dim=-1).mean(dim=-1)
58-
if ir.l == 0 and ir.p == 1:
59+
if ir.l == 0 and ir.p == 1 and mul == 1:
5960
assert torch.allclose(sq_norms, torch.as_tensor([0.0]))
6061
else:
6162
assert torch.allclose(sq_norms, torch.as_tensor([1.0]))

0 commit comments

Comments
 (0)