Skip to content

Commit 51b45bc

Browse files
committed
Fix API for OpenEquivariance.
1 parent a86468a commit 51b45bc

File tree

4 files changed

+69
-41
lines changed

4 files changed

+69
-41
lines changed

src/e3tools/nn/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
SeparableConv,
55
SeparableConvBlock,
66
FusedConv,
7-
FusedDepthwiseConv,
7+
FusedSeparableConv,
8+
FusedSeparableConvBlock,
89
)
910
from ._linear import Linear
1011
from ._gate import Gate, Gated, GateWrapper
@@ -26,7 +27,8 @@
2627
"EquivariantMLP",
2728
"ExtractIrreps",
2829
"FusedConv",
29-
"FusedDepthwiseConv",
30+
"FusedSeparableConv",
31+
"FusedSeparableConvBlock",
3032
"Gate",
3133
"GateWrapper",
3234
"Gated",

src/e3tools/nn/_conv.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from e3tools import scatter
1010

1111
from ._gate import Gated
12+
from ._linear import Linear
1213
from ._interaction import LinearSelfInteraction
1314
from ._mlp import ScalarMLP
1415
from ._tensor_product import SeparableTensorProduct, DepthwiseTensorProduct
@@ -77,6 +78,8 @@ def __init__(
7778
```
7879
is used.
7980
"""
81+
if not openequivariance_available:
82+
raise ImportError(f"OpenEquivariance could not be imported:\n{error_msg}")
8083

8184
super().__init__()
8285

@@ -101,20 +104,24 @@ def __init__(
101104

102105
self.radial_nn = radial_nn(edge_attr_dim, self.tp.weight_numel)
103106

104-
if not openequivariance_available:
105-
raise ImportError(f"OpenEquivariance could not be imported:\n{error_msg}")
107+
if isinstance(self.tp, SeparableTensorProduct):
108+
tp = self.tp.dtp
109+
self.has_post_linear = True
110+
else:
111+
tp = self.tp
112+
self.has_post_linear = False
106113

107-
# Remove path weight and path shape from instructions.
108-
oeq_instructions = [instruction[:5] for instruction in self.tp.instructions]
114+
# Remove path weight and path shape from instructions for OpenEquivariance.
115+
oeq_instructions = [instruction[:5] for instruction in tp.instructions]
109116
oeq_tpp = oeq.TPProblem(
110-
self.tp.irreps_in1,
111-
self.tp.irreps_in2,
112-
self.tp.irreps_out,
117+
tp.irreps_in1,
118+
tp.irreps_in2,
119+
tp.irreps_out,
113120
oeq_instructions,
114121
shared_weights=False,
115122
internal_weights=False,
116123
)
117-
self.fused_tp = oeq.TensorProductConv(
124+
self.fused_tp_conv = oeq.TensorProductConv(
118125
oeq_tpp, torch_op=True, deterministic=False, use_opaque=False
119126
)
120127

@@ -145,7 +152,10 @@ def forward(
145152

146153
src, dst = edge_index
147154
radial_attr = self.radial_nn(edge_attr)
148-
messages_agg = self.fused_tp(node_attr, edge_sh, radial_attr, dst, src)
155+
messages_agg = self.fused_tp_conv(node_attr, edge_sh, radial_attr, dst, src)
156+
if self.has_post_linear:
157+
messages_agg = self.tp.lin(messages_agg)
158+
149159
num_neighbors = scatter(
150160
torch.ones_like(src), src, dim=0, dim_size=N, reduce="sum"
151161
)
@@ -287,9 +297,9 @@ def __init__(self, *args, **kwargs):
287297
)
288298

289299

290-
class FusedDepthwiseConv(FusedConv):
300+
class FusedSeparableConv(FusedConv):
291301
"""
292-
Equivariant convolution layer using separable tensor product
302+
Equivariant convolution layer using separable tensor product, with fused OpenEquivariance kernels.
293303
294304
ref: https://arxiv.org/abs/1802.08219
295305
ref: https://arxiv.org/abs/2206.11990
@@ -299,7 +309,7 @@ def __init__(self, *args, **kwargs):
299309
super().__init__(
300310
*args,
301311
**kwargs,
302-
tensor_product=DepthwiseTensorProduct,
312+
tensor_product=SeparableTensorProduct,
303313
)
304314

305315

@@ -401,5 +411,22 @@ def __init__(self, *args, **kwargs):
401411
super().__init__(
402412
*args,
403413
**kwargs,
404-
conv=SeparableConv, # Explicitly set the convolution type to SeparableConv
414+
conv=SeparableConv,
415+
)
416+
417+
418+
class FusedSeparableConvBlock(ConvBlock):
419+
"""e3tools.nn.ConvBlock with FusedSeparableConv as the underlying convolution layer."""
420+
421+
def __init__(self, *args, **kwargs):
422+
"""
423+
Initializes the FusedSeparableConvBlock.
424+
425+
All arguments are passed directly to the parent ConvBlock,
426+
with the 'conv' argument specifically set to FusedSeparableConv.
427+
"""
428+
super().__init__(
429+
*args,
430+
**kwargs,
431+
conv=FusedSeparableConv,
405432
)

src/e3tools/nn/_tensor_product.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def __init__(
4848
internal_weights=False,
4949
shared_weights=False,
5050
)
51+
52+
# For book-keeping.
5153
self.irreps_out = self.tp.irreps_out
5254
self.weight_numel = self.tp.weight_numel
5355
self.instructions = self.tp.instructions
@@ -85,7 +87,6 @@ def __init__(
8587
self.lin = Linear(self.dtp.irreps_out, self.irreps_out)
8688

8789
# For book-keeping.
88-
self.instructions = self.dtp.instructions
8990
self.weight_numel = self.dtp.weight_numel
9091

9192
def forward(self, x, y, weight):

tests/test_fused_conv.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,33 @@
55
from torch import nn
66
import e3nn
77

8-
from e3tools.nn import Conv, FusedConv, DepthwiseTensorProduct, ScalarMLP
8+
from e3tools.nn import Conv, FusedConv, DepthwiseTensorProduct, ScalarMLP, SeparableConv, FusedSeparableConv, SeparableTensorProduct
99
from e3tools import radius_graph
1010

11-
TENSOR_PRODUCTS = [
12-
functools.partial(
13-
e3nn.o3.FullyConnectedTensorProduct,
14-
shared_weights=False,
15-
internal_weights=False,
16-
),
17-
DepthwiseTensorProduct,
18-
]
1911

20-
21-
@pytest.mark.parametrize("tensor_product", TENSOR_PRODUCTS)
12+
@pytest.mark.parametrize("tensor_product_type", [
13+
"default",
14+
"depthwise",
15+
"separable",
16+
])
2217
@pytest.mark.parametrize("seed", [0, 1])
23-
def test_fused_conv(tensor_product, seed):
18+
def test_fused_conv(tensor_product_type: str, seed: int):
2419
if not torch.cuda.is_available():
2520
pytest.skip("CUDA is not available")
2621

22+
if tensor_product_type == "default":
23+
tensor_product = functools.partial(
24+
e3nn.o3.FullyConnectedTensorProduct,
25+
shared_weights=False,
26+
internal_weights=False,
27+
)
28+
29+
elif tensor_product_type == "depthwise":
30+
tensor_product = DepthwiseTensorProduct
31+
32+
elif tensor_product_type == "separable":
33+
tensor_product = SeparableTensorProduct
34+
2735
torch.manual_seed(seed)
2836
torch.set_default_device("cuda")
2937

@@ -33,34 +41,24 @@ def test_fused_conv(tensor_product, seed):
3341
irreps_in = e3nn.o3.Irreps("10x0e + 4x1o + 1x2e")
3442
irreps_sh = irreps_in.spherical_harmonics(2)
3543

36-
tp = tensor_product(irreps_in, irreps_sh, irreps_in)
37-
common_radial_nn = ScalarMLP(
38-
in_features=edge_attr_dim,
39-
out_features=tp.weight_numel,
40-
hidden_features=[edge_attr_dim],
41-
activation_layer=nn.SiLU,
42-
)
43-
44-
def radial_nn(edge_attr_dim: int, num_elements: int) -> nn.Module:
45-
return common_radial_nn
46-
4744
layer = Conv(
4845
irreps_in=irreps_in,
4946
irreps_out=irreps_in,
5047
irreps_sh=irreps_sh,
51-
radial_nn=radial_nn,
5248
edge_attr_dim=edge_attr_dim,
5349
tensor_product=tensor_product,
5450
)
5551
fused_layer = FusedConv(
5652
irreps_in=irreps_in,
5753
irreps_out=irreps_in,
5854
irreps_sh=irreps_sh,
59-
radial_nn=radial_nn,
6055
edge_attr_dim=edge_attr_dim,
6156
tensor_product=tensor_product,
6257
)
6358

59+
# Copy weights.
60+
fused_layer.load_state_dict(layer.state_dict())
61+
6462
pos = torch.randn(N, 3)
6563
node_attr = layer.irreps_in.randn(N, -1)
6664

0 commit comments

Comments
 (0)