Skip to content

Commit 73f970d

Browse files
committed
Make openeq import only with FusedConv. Add Repeat().
1 parent d08f3f6 commit 73f970d

File tree

3 files changed

+30
-10
lines changed

3 files changed

+30
-10
lines changed

src/e3tools/nn/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ._transformer import Attention, MultiheadAttention, TransformerBlock
1818
from ._extract_irreps import ExtractIrreps
1919
from ._scaling import ScaleIrreps
20+
from ._repeat import Repeat
2021

2122
__all__ = [
2223
"Attention",
@@ -37,6 +38,7 @@
3738
"LinearSelfInteraction",
3839
"MulToAxis",
3940
"MultiheadAttention",
41+
"Repeat",
4042
"ScalarMLP",
4143
"ScaleIrreps",
4244
"SeparableConv",

src/e3tools/nn/_conv.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,6 @@
1313
from ._mlp import ScalarMLP
1414
from ._tensor_product import SeparableTensorProduct
1515

16-
try:
17-
import openequivariance as oeq
18-
19-
openequivariance_available = True
20-
except ImportError as e:
21-
error_msg = str(e)
22-
openequivariance_available = False
23-
2416

2517
class FusedConv(nn.Module):
2618
"""
@@ -77,8 +69,10 @@ def __init__(
7769
```
7870
is used.
7971
"""
80-
if not openequivariance_available:
81-
raise ImportError(f"OpenEquivariance could not be imported:\n{error_msg}")
72+
try:
73+
import openequivariance as oeq
74+
except ImportError as e:
75+
raise ImportError(f"OpenEquivariance could not be imported: {e}")
8276

8377
super().__init__()
8478

src/e3tools/nn/_repeat.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch
2+
from torch import nn
3+
import e3nn.o3
4+
5+
6+
7+
class Repeat(nn.Module):
8+
"""Repeat the irreps along the last axis."""
9+
10+
def __init__(self, irreps_in: e3nn.o3.Irreps, repeats: int):
11+
super().__init__()
12+
self.irreps_in = e3nn.o3.Irreps(irreps_in)
13+
self.repeats = repeats
14+
self.irreps_out = e3nn.o3.Irreps(
15+
[(mul * repeats, ir) for mul, ir in irreps_in]
16+
)
17+
18+
def forward(self, x: torch.Tensor) -> torch.Tensor:
19+
"""Repeat the features along the last axis.
20+
21+
If `x` has shape `[..., irreps_in.dim]`, the output will have shape
22+
`[..., repeats * irreps_in.dim]`.
23+
"""
24+
return x.repeat_interleave(self.repeats, dim=-1)

0 commit comments

Comments
 (0)