@@ -24,6 +24,7 @@ def test_equivariance(irreps_in: str):
2424 [
2525 "0e + 1o" ,
2626 "32x0e + 1o + 2e" ,
27+ "0e + 4x1o + 5e" ,
2728 "3x1o + 2x2o" ,
2829 "8x1o + 2x2o + 1x3o" ,
2930 "3x1o + 2x2o + 1x3o + 1x4o" ,
@@ -43,7 +44,7 @@ def test_layer_norm_compiled(irreps_in: str, seed: int, batch_size: int = 8):
4344 assert torch .allclose (output , output_compiled )
4445
4546
46- @pytest .mark .parametrize ("irreps_in" , ["0e + 1o" , "32x0e + 1o + 2e" , "3x1o + 2x2o" ])
47+ @pytest .mark .parametrize ("irreps_in" , ["0e + 1o" , "32x0e + 1o + 2e" , "0e + 4x1o + 5e" , " 3x1o + 2x2o" ])
4748def test_layer_norm (irreps_in : str ):
4849 irreps_in = e3nn .o3 .Irreps (irreps_in )
4950 layer = LayerNorm (irreps_in )
@@ -53,9 +54,9 @@ def test_layer_norm(irreps_in: str):
5354 input = irreps_in .randn (- 1 )
5455 output = layer (input )
5556
56- for _ , ir , field in unpack_irreps (output , irreps_in ):
57+ for mul , ir , field in unpack_irreps (output , irreps_in ):
5758 sq_norms = field .norm (dim = - 1 , keepdim = True ).pow (2 ).sum (dim = - 1 ).mean (dim = - 1 )
58- if ir .l == 0 and ir .p == 1 :
59+ if ir .l == 0 and ir .p == 1 and mul == 1 :
5960 assert torch .allclose (sq_norms , torch .as_tensor ([0.0 ]))
6061 else :
6162 assert torch .allclose (sq_norms , torch .as_tensor ([1.0 ]))
0 commit comments