
Commit a86468a

Merge pull request #5 from prescient-design/openequivariance

Add support for OpenEquivariance Kernels

2 parents 2b6ef2a + 5b5b8b7

14 files changed: +557 -213 lines

pyproject.toml

Lines changed: 3 additions & 0 deletions

@@ -29,6 +29,9 @@ dev = [
     "ruff",
     "ipykernel>=6.30.1",
 ]
+openequivariance = [
+    "openequivariance>=0.4.1",
+]
 
 [tool.ruff.lint]
 ignore = ["F722"]

src/e3tools/nn/__init__.py

Lines changed: 12 additions & 4 deletions

@@ -1,11 +1,18 @@
-from ._conv import Conv, ConvBlock, ExperimentalConv, SeparableConv, SeparableConvBlock
+from ._conv import (
+    Conv,
+    ConvBlock,
+    SeparableConv,
+    SeparableConvBlock,
+    FusedConv,
+    FusedDepthwiseConv,
+)
 from ._linear import Linear
 from ._gate import Gate, Gated, GateWrapper
 from ._interaction import LinearSelfInteraction
 from ._layer_norm import LayerNorm
 from ._mlp import EquivariantMLP, ScalarMLP
 from ._axis_to_mul import AxisToMul, MulToAxis
-from ._tensor_product import ExperimentalTensorProduct, SeparableTensorProduct
+from ._tensor_product import SeparableTensorProduct, DepthwiseTensorProduct
 from ._transformer import Attention, MultiheadAttention, TransformerBlock
 from ._extract_irreps import ExtractIrreps
 from ._scaling import ScaleIrreps
@@ -15,10 +22,11 @@
     "AxisToMul",
     "Conv",
     "ConvBlock",
+    "DepthwiseTensorProduct",
     "EquivariantMLP",
-    "ExperimentalConv",
-    "ExperimentalTensorProduct",
     "ExtractIrreps",
+    "FusedConv",
+    "FusedDepthwiseConv",
     "Gate",
     "GateWrapper",
     "Gated",

src/e3tools/nn/_conv.py

Lines changed: 162 additions & 5 deletions

@@ -11,7 +11,146 @@
 from ._gate import Gated
 from ._interaction import LinearSelfInteraction
 from ._mlp import ScalarMLP
-from ._tensor_product import ExperimentalTensorProduct, SeparableTensorProduct
+from ._tensor_product import SeparableTensorProduct, DepthwiseTensorProduct
+
+try:
+    import openequivariance as oeq
+
+    openequivariance_available = True
+except ImportError as e:
+    error_msg = str(e)
+    openequivariance_available = False
+
+
+class FusedConv(nn.Module):
+    """
+    Fused version of the equivariant convolution layer, using OpenEquivariance kernels.
+
+    ref: https://arxiv.org/abs/1802.08219
+    ref: https://arxiv.org/abs/2501.13986
+    """
+
+    def __init__(
+        self,
+        irreps_in: Union[str, e3nn.o3.Irreps],
+        irreps_out: Union[str, e3nn.o3.Irreps],
+        irreps_sh: Union[str, e3nn.o3.Irreps],
+        edge_attr_dim: int,
+        radial_nn: Optional[Callable[..., nn.Module]] = None,
+        tensor_product: Optional[Callable[..., nn.Module]] = None,
+    ):
+        """
+        Parameters
+        ----------
+        irreps_in: e3nn.o3.Irreps
+            Input node feature irreps
+        irreps_out: e3nn.o3.Irreps
+            Output node feature irreps
+        irreps_sh: e3nn.o3.Irreps
+            Edge spherical harmonic irreps
+        edge_attr_dim: int
+            Dimension of scalar edge attributes to be passed to radial_nn
+        radial_nn: Optional[Callable[..., nn.Module]]
+            Factory function for the radial NN used to generate tensor product weights.
+            Should be callable as `radial_nn(in_features, out_features)`.
+            If `None`, then
+            ```
+            functools.partial(
+                e3tools.nn.ScalarMLP,
+                hidden_features=[edge_attr_dim],
+                activation_layer=nn.SiLU,
+            )
+            ```
+            is used.
+        tensor_product: Optional[Callable[..., nn.Module]]
+            Factory function for the tensor product used to mix input node
+            representations with edge spherical harmonics.
+            Should be callable as `tensor_product(irreps_in, irreps_sh, irreps_out)`
+            and return an object with a `weight_numel` property defined.
+            If `None`, then
+            ```
+            functools.partial(
+                e3nn.o3.FullyConnectedTensorProduct,
+                shared_weights=False,
+                internal_weights=False,
+            )
+            ```
+            is used.
+        """
+
+        super().__init__()
+
+        self.irreps_in = e3nn.o3.Irreps(irreps_in)
+        self.irreps_out = e3nn.o3.Irreps(irreps_out)
+        self.irreps_sh = e3nn.o3.Irreps(irreps_sh)
+
+        if tensor_product is None:
+            tensor_product = functools.partial(
+                e3nn.o3.FullyConnectedTensorProduct,
+                shared_weights=False,
+                internal_weights=False,
+            )
+
+        self.tp = tensor_product(irreps_in, irreps_sh, irreps_out)
+        if radial_nn is None:
+            radial_nn = functools.partial(
+                ScalarMLP,
+                hidden_features=[edge_attr_dim],
+                activation_layer=nn.SiLU,
+            )
+
+        self.radial_nn = radial_nn(edge_attr_dim, self.tp.weight_numel)
+
+        if not openequivariance_available:
+            raise ImportError(f"OpenEquivariance could not be imported:\n{error_msg}")
+
+        # Remove path weight and path shape from instructions.
+        oeq_instructions = [instruction[:5] for instruction in self.tp.instructions]
+        oeq_tpp = oeq.TPProblem(
+            self.tp.irreps_in1,
+            self.tp.irreps_in2,
+            self.tp.irreps_out,
+            oeq_instructions,
+            shared_weights=False,
+            internal_weights=False,
+        )
+        self.fused_tp = oeq.TensorProductConv(
+            oeq_tpp, torch_op=True, deterministic=False, use_opaque=False
+        )
+
+    def forward(
+        self,
+        node_attr: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_attr: torch.Tensor,
+        edge_sh: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Computes the forward pass of the equivariant convolution.
+
+        Let N be the number of nodes, and E be the number of edges.
+
+        Parameters
+        ----------
+        node_attr: [N, irreps_in.dim]
+        edge_index: [2, E]
+        edge_attr: [E, edge_attr_dim]
+        edge_sh: [E, irreps_sh.dim]
+
+        Returns
+        -------
+        out: [N, irreps_out.dim]
+        """
+        N = node_attr.shape[0]
+
+        src, dst = edge_index
+        radial_attr = self.radial_nn(edge_attr)
+        messages_agg = self.fused_tp(node_attr, edge_sh, radial_attr, dst, src)
+        num_neighbors = scatter(
+            torch.ones_like(src), src, dim=0, dim_size=N, reduce="sum"
+        )
+        out = messages_agg / num_neighbors.clamp_min(1).unsqueeze(1)
+        return out
 
 
 class Conv(nn.Module):
@@ -92,10 +231,21 @@ def __init__(
 
         self.radial_nn = radial_nn(edge_attr_dim, self.tp.weight_numel)
 
-    def apply_per_edge(self, node_attr_src, edge_attr, edge_sh):
+    def apply_per_edge(
+        self,
+        node_attr_src: torch.Tensor,
+        edge_attr: torch.Tensor,
+        edge_sh: torch.Tensor,
+    ) -> torch.Tensor:
         return self.tp(node_attr_src, edge_sh, self.radial_nn(edge_attr))
 
-    def forward(self, node_attr, edge_index, edge_attr, edge_sh):
+    def forward(
+        self,
+        node_attr: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_attr: torch.Tensor,
+        edge_sh: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Computes the forward pass of the equivariant convolution.
 
@@ -137,12 +287,19 @@ def __init__(self, *args, **kwargs):
         )
 
 
-class ExperimentalConv(Conv):
+class FusedDepthwiseConv(FusedConv):
+    """
+    Fused equivariant convolution layer using a depthwise tensor product.
+
+    ref: https://arxiv.org/abs/1802.08219
+    ref: https://arxiv.org/abs/2206.11990
+    """
+
     def __init__(self, *args, **kwargs):
         super().__init__(
             *args,
             **kwargs,
-            tensor_product=ExperimentalTensorProduct,
+            tensor_product=DepthwiseTensorProduct,
         )
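
For orientation, here is a minimal usage sketch of the new fused layer. The irreps choices, sizes, and random graph are illustrative assumptions, and a CUDA device is assumed since OpenEquivariance provides GPU kernels:

import torch
import e3nn.o3
from e3tools.nn import FusedConv

device = "cuda"  # OpenEquivariance targets GPU execution

conv = FusedConv(
    irreps_in="16x0e + 8x1o",  # illustrative choices
    irreps_out="16x0e + 8x1o",
    irreps_sh="1x0e + 1x1o",
    edge_attr_dim=8,
).to(device)

N, E = 32, 128  # toy graph
node_attr = torch.randn(N, conv.irreps_in.dim, device=device)
edge_index = torch.randint(0, N, (2, E), device=device)
edge_attr = torch.randn(E, 8, device=device)
edge_vec = torch.randn(E, 3, device=device)
edge_sh = e3nn.o3.spherical_harmonics(conv.irreps_sh, edge_vec, normalize=True)

out = conv(node_attr, edge_index, edge_attr, edge_sh)  # [N, conv.irreps_out.dim]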

src/e3tools/nn/_tensor_product.py

Lines changed: 28 additions & 27 deletions

@@ -2,14 +2,15 @@
 
 import e3nn
 import e3nn.o3
+import torch
 from torch import nn
 
 from ._linear import Linear
 
 
-class SeparableTensorProduct(nn.Module):
+class DepthwiseTensorProduct(nn.Module):
     """
-    Tensor product factored into depthwise and pointwise components
+    Depthwise tensor product
 
     ref: https://arxiv.org/abs/2206.11990
     ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157
@@ -24,45 +25,46 @@ def __init__(
         super().__init__()
         self.irreps_in1 = e3nn.o3.Irreps(irreps_in1)
         self.irreps_in2 = e3nn.o3.Irreps(irreps_in2)
-        self.irreps_out = e3nn.o3.Irreps(irreps_out)
+        irreps_out = e3nn.o3.Irreps(irreps_out)
 
         irreps_out_dtp = []
         instructions_dtp = []
 
         for i, (mul, ir_in1) in enumerate(self.irreps_in1):
             for j, (_, ir_in2) in enumerate(self.irreps_in2):
                 for ir_out in ir_in1 * ir_in2:
-                    if ir_out in self.irreps_out or ir_out == e3nn.o3.Irrep(0, 1):
+                    if ir_out in irreps_out or ir_out == e3nn.o3.Irrep(0, 1):
                         k = len(irreps_out_dtp)
                         irreps_out_dtp.append((mul, ir_out))
                         instructions_dtp.append((i, j, k, "uvu", True))
 
         irreps_out_dtp = e3nn.o3.Irreps(irreps_out_dtp)
 
-        # depth wise
-        self.dtp = e3nn.o3.TensorProduct(
+        self.tp = e3nn.o3.TensorProduct(
             irreps_in1,
             irreps_in2,
             irreps_out_dtp,
             instructions_dtp,
             internal_weights=False,
             shared_weights=False,
         )
-
-        # point wise
-        self.lin = Linear(irreps_out_dtp, self.irreps_out)
-
-        self.weight_numel = self.dtp.weight_numel
-
-    def forward(self, x, y, weight):
-        out = self.dtp(x, y, weight)
-        out = self.lin(out)
+        self.irreps_out = self.tp.irreps_out
+        self.weight_numel = self.tp.weight_numel
+        self.instructions = self.tp.instructions
+
+    def forward(
+        self, x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor
+    ) -> torch.Tensor:
+        out = self.tp(x, y, weight)
         return out
 
 
-class ExperimentalTensorProduct(nn.Module):
+class SeparableTensorProduct(nn.Module):
     """
-    Compileable tensor product
+    Tensor product factored into depthwise and pointwise components
+
+    ref: https://arxiv.org/abs/2206.11990
+    ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157
     """
 
     def __init__(
@@ -76,18 +78,17 @@ def __init__(
         self.irreps_in2 = e3nn.o3.Irreps(irreps_in2)
         self.irreps_out = e3nn.o3.Irreps(irreps_out)
 
-        self.tp = e3nn.o3.FullTensorProductv2(self.irreps_in1, self.irreps_in2)
-
-        self.lin = Linear(
-            self.tp.irreps_out,
-            self.irreps_out,
-            internal_weights=False,
-            shared_weights=False,
+        # Depthwise and pointwise
+        self.dtp = DepthwiseTensorProduct(
+            self.irreps_in1, self.irreps_in2, self.irreps_out
         )
+        self.lin = Linear(self.dtp.irreps_out, self.irreps_out)
 
-        self.weight_numel = self.lin.weight_numel
+        # For book-keeping.
+        self.instructions = self.dtp.instructions
+        self.weight_numel = self.dtp.weight_numel
 
     def forward(self, x, y, weight):
-        out = self.tp(x, y)
-        out = self.lin(out, weight)
+        out = self.dtp(x, y, weight)
+        out = self.lin(out)
         return out
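
To make the refactor concrete: SeparableTensorProduct now delegates the externally-weighted depthwise step to DepthwiseTensorProduct and follows it with an internally-weighted pointwise Linear. A small sketch with illustrative irreps:

import torch
import e3nn.o3
from e3tools.nn import SeparableTensorProduct

irreps_x = e3nn.o3.Irreps("8x0e + 4x1o")
irreps_sh = e3nn.o3.Irreps("1x0e + 1x1o")
tp = SeparableTensorProduct(irreps_x, irreps_sh, irreps_x)

E = 10  # e.g. one tensor product per edge
x = torch.randn(E, irreps_x.dim)
y = torch.randn(E, irreps_sh.dim)
w = torch.randn(E, tp.weight_numel)  # external per-edge weights (shared_weights=False)
out = tp(x, y, w)  # [E, tp.irreps_out.dim]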

src/e3tools/utils/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from ._default_dtype_manager import default_dtype_manager
+
+__all__ = ["default_dtype_manager"]
src/e3tools/utils/_default_dtype_manager.py

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+from contextlib import contextmanager
+
+import torch
+
+
+@contextmanager
+def default_dtype_manager(dtype):
+    original_dtype = torch.get_default_dtype()
+    try:
+        torch.set_default_dtype(dtype)
+        yield
+    finally:
+        torch.set_default_dtype(original_dtype)
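
The context manager temporarily swaps PyTorch's global default dtype and restores the previous value even if the body raises. A quick usage sketch, assuming the process default is float32:

import torch
from e3tools.utils import default_dtype_manager

with default_dtype_manager(torch.float64):
    x = torch.zeros(3)  # created as float64 inside the block

assert x.dtype == torch.float64
assert torch.get_default_dtype() == torch.float32  # restored on exit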
