File tree Expand file tree Collapse file tree 3 files changed +30
-10
lines changed
Expand file tree Collapse file tree 3 files changed +30
-10
lines changed Original file line number Diff line number Diff line change 1717from ._transformer import Attention , MultiheadAttention , TransformerBlock
1818from ._extract_irreps import ExtractIrreps
1919from ._scaling import ScaleIrreps
20+ from ._repeat import Repeat
2021
2122__all__ = [
2223 "Attention" ,
3738 "LinearSelfInteraction" ,
3839 "MulToAxis" ,
3940 "MultiheadAttention" ,
41+ "Repeat" ,
4042 "ScalarMLP" ,
4143 "ScaleIrreps" ,
4244 "SeparableConv" ,
Original file line number Diff line number Diff line change 1313from ._mlp import ScalarMLP
1414from ._tensor_product import SeparableTensorProduct
1515
16- try :
17- import openequivariance as oeq
18-
19- openequivariance_available = True
20- except ImportError as e :
21- error_msg = str (e )
22- openequivariance_available = False
23-
2416
2517class FusedConv (nn .Module ):
2618 """
@@ -77,8 +69,10 @@ def __init__(
7769 ```
7870 is used.
7971 """
80- if not openequivariance_available :
81- raise ImportError (f"OpenEquivariance could not be imported:\n { error_msg } " )
72+ try :
73+ import openequivariance as oeq
74+ except ImportError as e :
75+ raise ImportError (f"OpenEquivariance could not be imported: { e } " )
8276
8377 super ().__init__ ()
8478
Original file line number Diff line number Diff line change 1+ import torch
2+ from torch import nn
3+ import e3nn .o3
4+
5+
6+
7+ class Repeat (nn .Module ):
8+ """Repeat the irreps along the last axis."""
9+
10+ def __init__ (self , irreps_in : e3nn .o3 .Irreps , repeats : int ):
11+ super ().__init__ ()
12+ self .irreps_in = e3nn .o3 .Irreps (irreps_in )
13+ self .repeats = repeats
14+ self .irreps_out = e3nn .o3 .Irreps (
15+ [(mul * repeats , ir ) for mul , ir in irreps_in ]
16+ )
17+
18+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
19+ """Repeat the features along the last axis.
20+
21+ If `x` has shape `[..., irreps_in.dim]`, the output will have shape
22+ `[..., repeats * irreps_in.dim]`.
23+ """
24+ return x .repeat_interleave (self .repeats , dim = - 1 )
You can’t perform that action at this time.
0 commit comments