File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change 55class ExtractIrreps (torch .nn .Module ):
66 """Extracts specific irreps from a e3nn tensor."""
77
8- def __init__ (self , irreps_in : e3nn .o3 .Irreps , irreps_extract : e3nn .o3 .Irrep ):
8+ def __init__ (self , irreps_in : e3nn .o3 .Irreps , irrep_extract : e3nn .o3 .Irrep ):
99 super ().__init__ ()
1010
1111 self .irreps_in = e3nn .o3 .Irreps (irreps_in )
12- self .irreps_extract = e3nn .o3 .Irrep (irreps_extract )
12+ self .irrep_extract = e3nn .o3 .Irrep (irrep_extract )
1313
1414 irreps_out = e3nn .o3 .Irreps ()
1515 slices = []
1616 for (mul , ir ), ir_slice in zip (self .irreps_in , self .irreps_in .slices ()):
17- if ir .l == self .irreps_extract . l and ir . p == self .irreps_extract . p :
17+ if ( ir .l , ir . p ) == ( self .irrep_extract . l , self .irrep_extract . p ) :
1818 slices .append (ir_slice )
1919 irreps_out += e3nn .o3 .Irreps (f"{ mul } x{ ir } " )
2020
2121 if len (slices ) == 0 :
22- raise ValueError (f"irreps { irreps_extract } not found in { irreps_in } " )
22+ raise ValueError (f"Irreps { irrep_extract } not found in { irreps_in } " )
2323
2424 self .slices = slices
2525 self .irreps_out = irreps_out
You can’t perform that action at this time.
0 commit comments