99from e3tools import scatter
1010
1111from ._gate import Gated
12+ from ._linear import Linear
1213from ._interaction import LinearSelfInteraction
1314from ._mlp import ScalarMLP
1415from ._tensor_product import SeparableTensorProduct , DepthwiseTensorProduct
@@ -77,6 +78,8 @@ def __init__(
7778 ```
7879 is used.
7980 """
81+ if not openequivariance_available :
82+ raise ImportError (f"OpenEquivariance could not be imported:\n { error_msg } " )
8083
8184 super ().__init__ ()
8285
@@ -101,20 +104,24 @@ def __init__(
101104
102105 self .radial_nn = radial_nn (edge_attr_dim , self .tp .weight_numel )
103106
104- if not openequivariance_available :
105- raise ImportError (f"OpenEquivariance could not be imported:\n { error_msg } " )
107+ if isinstance (self .tp , SeparableTensorProduct ):
108+ tp = self .tp .dtp
109+ self .has_post_linear = True
110+ else :
111+ tp = self .tp
112+ self .has_post_linear = False
106113
107- # Remove path weight and path shape from instructions.
108- oeq_instructions = [instruction [:5 ] for instruction in self . tp .instructions ]
114+ # Remove path weight and path shape from instructions for OpenEquivariance .
115+ oeq_instructions = [instruction [:5 ] for instruction in tp .instructions ]
109116 oeq_tpp = oeq .TPProblem (
110- self . tp .irreps_in1 ,
111- self . tp .irreps_in2 ,
112- self . tp .irreps_out ,
117+ tp .irreps_in1 ,
118+ tp .irreps_in2 ,
119+ tp .irreps_out ,
113120 oeq_instructions ,
114121 shared_weights = False ,
115122 internal_weights = False ,
116123 )
117- self .fused_tp = oeq .TensorProductConv (
124+ self .fused_tp_conv = oeq .TensorProductConv (
118125 oeq_tpp , torch_op = True , deterministic = False , use_opaque = False
119126 )
120127
@@ -145,7 +152,10 @@ def forward(
145152
146153 src , dst = edge_index
147154 radial_attr = self .radial_nn (edge_attr )
148- messages_agg = self .fused_tp (node_attr , edge_sh , radial_attr , dst , src )
155+ messages_agg = self .fused_tp_conv (node_attr , edge_sh , radial_attr , dst , src )
156+ if self .has_post_linear :
157+ messages_agg = self .tp .lin (messages_agg )
158+
149159 num_neighbors = scatter (
150160 torch .ones_like (src ), src , dim = 0 , dim_size = N , reduce = "sum"
151161 )
@@ -287,9 +297,9 @@ def __init__(self, *args, **kwargs):
287297 )
288298
289299
290- class FusedDepthwiseConv (FusedConv ):
300+ class FusedSeparableConv (FusedConv ):
291301 """
292- Equivariant convolution layer using separable tensor product
302+ Equivariant convolution layer using separable tensor product, with fused OpenEquivariance kernels.
293303
294304 ref: https://arxiv.org/abs/1802.08219
295305 ref: https://arxiv.org/abs/2206.11990
@@ -299,7 +309,7 @@ def __init__(self, *args, **kwargs):
299309 super ().__init__ (
300310 * args ,
301311 ** kwargs ,
302- tensor_product = DepthwiseTensorProduct ,
312+ tensor_product = SeparableTensorProduct ,
303313 )
304314
305315
@@ -401,5 +411,22 @@ def __init__(self, *args, **kwargs):
401411 super ().__init__ (
402412 * args ,
403413 ** kwargs ,
404- conv = SeparableConv , # Explicitly set the convolution type to SeparableConv
414+ conv = SeparableConv ,
415+ )
416+
417+
418+ class FusedSeparableConvBlock (ConvBlock ):
419+ """e3tools.nn.ConvBlock with FusedSeparableConv as the underlying convolution layer."""
420+
421+ def __init__ (self , * args , ** kwargs ):
422+ """
423+ Initializes the SeparableConvBlock.
424+
425+ All arguments are passed directly to the parent ConvBlock,
426+ with the 'conv' argument specifically set to SeparableConv.
427+ """
428+ super ().__init__ (
429+ * args ,
430+ ** kwargs ,
431+ conv = FusedSeparableConv ,
405432 )
0 commit comments