Skip to content

Commit db4afdb

Browse files
committed
Fix name.
1 parent 2b5fb7f commit db4afdb

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/e3tools/nn/_extract_irreps.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,21 @@
55
class ExtractIrreps(torch.nn.Module):
66
"""Extracts specific irreps from a e3nn tensor."""
77

8-
def __init__(self, irreps_in: e3nn.o3.Irreps, irreps_extract: e3nn.o3.Irrep):
8+
def __init__(self, irreps_in: e3nn.o3.Irreps, irrep_extract: e3nn.o3.Irrep):
99
super().__init__()
1010

1111
self.irreps_in = e3nn.o3.Irreps(irreps_in)
12-
self.irreps_extract = e3nn.o3.Irrep(irreps_extract)
12+
self.irrep_extract = e3nn.o3.Irrep(irrep_extract)
1313

1414
irreps_out = e3nn.o3.Irreps()
1515
slices = []
1616
for (mul, ir), ir_slice in zip(self.irreps_in, self.irreps_in.slices()):
17-
if ir.l == self.irreps_extract.l and ir.p == self.irreps_extract.p:
17+
if (ir.l, ir.p) == (self.irrep_extract.l, self.irrep_extract.p):
1818
slices.append(ir_slice)
1919
irreps_out += e3nn.o3.Irreps(f"{mul}x{ir}")
2020

2121
if len(slices) == 0:
22-
raise ValueError(f"irreps {irreps_extract} not found in {irreps_in}")
22+
raise ValueError(f"Irreps {irrep_extract} not found in {irreps_in}")
2323

2424
self.slices = slices
2525
self.irreps_out = irreps_out

0 commit comments

Comments
 (0)