|
11 | 11 | from ._gate import Gated |
12 | 12 | from ._interaction import LinearSelfInteraction |
13 | 13 | from ._mlp import ScalarMLP |
14 | | -from ._tensor_product import ExperimentalTensorProduct, SeparableTensorProduct |
| 14 | +from ._tensor_product import SeparableTensorProduct, DepthwiseTensorProduct |
| 15 | + |
# Optional dependency: the fused convolution path requires OpenEquivariance.
# Record availability (and the failure reason) instead of raising at import
# time, so the rest of the module stays usable without it.
try:
    import openequivariance as oeq
except ImportError as e:
    error_msg = str(e)
    openequivariance_available = False
else:
    openequivariance_available = True
| 24 | + |
class FusedConv(nn.Module):
    """
    Fused version of equivariant convolution layer with OpenEquivariance kernels.

    ref: https://arxiv.org/abs/1802.08219
    ref: https://arxiv.org/abs/2501.13986
    """

    def __init__(
        self,
        irreps_in: Union[str, e3nn.o3.Irreps],
        irreps_out: Union[str, e3nn.o3.Irreps],
        irreps_sh: Union[str, e3nn.o3.Irreps],
        edge_attr_dim: int,
        radial_nn: Optional[Callable[..., nn.Module]] = None,
        tensor_product: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Parameters
        ----------
        irreps_in: e3nn.o3.Irreps
            Input node feature irreps
        irreps_out: e3nn.o3.Irreps
            Output node feature irreps
        irreps_sh: e3nn.o3.Irreps
            Edge spherical harmonic irreps
        edge_attr_dim: int
            Dimension of scalar edge attributes to be passed to radial_nn
        radial_nn: Optional[Callable[..., nn.Module]]
            Factory function for radial nn used to generate tensor product weights.
            Should be callable as radial_nn(in_features, out_features)
            if `None` then
            ```
            functools.partial(
                e3tools.nn.ScalarMLP,
                hidden_features=[edge_attr_dim],
                activation_layer=nn.SiLU,
            )
            ```
            is used.
        tensor_product: Optional[Callable[..., nn.Module]]
            Factory function for tensor product used to mix input node
            representations with edge spherical harmonics.
            Should be callable as `tensor_product(irreps_in, irreps_sh, irreps_out)`
            and return an object with `weight_numel` property defined
            If `None` then
            ```
            functools.partial(
                e3nn.o3.FullyConnectedTensorProduct,
                shared_weights=False,
                internal_weights=False,
            )
            ```
            is used.

        Raises
        ------
        ImportError
            If `openequivariance` could not be imported.
        """

        super().__init__()

        # Fail fast: the original code built self.tp and self.radial_nn before
        # checking availability, doing throwaway work on the failure path.
        if not openequivariance_available:
            raise ImportError(f"OpenEquivariance could not be imported:\n{error_msg}")

        self.irreps_in = e3nn.o3.Irreps(irreps_in)
        self.irreps_out = e3nn.o3.Irreps(irreps_out)
        self.irreps_sh = e3nn.o3.Irreps(irreps_sh)

        if tensor_product is None:
            tensor_product = functools.partial(
                e3nn.o3.FullyConnectedTensorProduct,
                shared_weights=False,
                internal_weights=False,
            )

        # The e3nn tensor product supplies the instruction list and weight
        # layout that the fused OpenEquivariance kernel is built from below.
        self.tp = tensor_product(irreps_in, irreps_sh, irreps_out)

        if radial_nn is None:
            radial_nn = functools.partial(
                ScalarMLP,
                hidden_features=[edge_attr_dim],
                activation_layer=nn.SiLU,
            )

        self.radial_nn = radial_nn(edge_attr_dim, self.tp.weight_numel)

        # Remove path weight and path shape from instructions; oeq.TPProblem
        # expects only the leading (i_in1, i_in2, i_out, mode, has_weight).
        oeq_instructions = [instruction[:5] for instruction in self.tp.instructions]
        oeq_tpp = oeq.TPProblem(
            self.tp.irreps_in1,
            self.tp.irreps_in2,
            self.tp.irreps_out,
            oeq_instructions,
            shared_weights=False,
            internal_weights=False,
        )
        self.fused_tp = oeq.TensorProductConv(
            oeq_tpp, torch_op=True, deterministic=False, use_opaque=False
        )

    def forward(
        self,
        node_attr: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_sh: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the forward pass of the equivariant convolution.

        Let N be the number of nodes, and E be the number of edges

        Parameters
        ----------
        node_attr: [N, irreps_in.dim]
        edge_index: [2, E]
        edge_attr: [E, edge_attr_dim]
        edge_sh: [E, irreps_sh.dim]

        Returns
        -------
        out: [N, irreps_out.dim]
        """
        N = node_attr.shape[0]

        src, dst = edge_index
        # Per-edge tensor-product weights generated from scalar edge attributes.
        radial_attr = self.radial_nn(edge_attr)
        # Fused gather -> tensor product -> scatter performed in one kernel.
        messages_agg = self.fused_tp(node_attr, edge_sh, radial_attr, dst, src)
        # NOTE(review): neighbor counts are accumulated over `src` while the
        # fused kernel is passed (dst, src); this preserves the original
        # behavior but implicitly assumes a symmetric edge_index
        # (in-degree == out-degree per node) — confirm against callers.
        num_neighbors = scatter(
            torch.ones_like(src), src, dim=0, dim_size=N, reduce="sum"
        )
        # Mean aggregation; clamp_min avoids division by zero for isolated nodes.
        out = messages_agg / num_neighbors.clamp_min(1).unsqueeze(1)
        return out
15 | 154 |
|
16 | 155 |
|
17 | 156 | class Conv(nn.Module): |
@@ -92,10 +231,21 @@ def __init__( |
92 | 231 |
|
93 | 232 | self.radial_nn = radial_nn(edge_attr_dim, self.tp.weight_numel) |
94 | 233 |
|
95 | | - def apply_per_edge(self, node_attr_src, edge_attr, edge_sh): |
| 234 | + def apply_per_edge( |
| 235 | + self, |
| 236 | + node_attr_src: torch.Tensor, |
| 237 | + edge_attr: torch.Tensor, |
| 238 | + edge_sh: torch.Tensor, |
| 239 | + ) -> torch.Tensor: |
96 | 240 | return self.tp(node_attr_src, edge_sh, self.radial_nn(edge_attr)) |
97 | 241 |
|
98 | | - def forward(self, node_attr, edge_index, edge_attr, edge_sh): |
| 242 | + def forward( |
| 243 | + self, |
| 244 | + node_attr: torch.Tensor, |
| 245 | + edge_index: torch.Tensor, |
| 246 | + edge_attr: torch.Tensor, |
| 247 | + edge_sh: torch.Tensor, |
| 248 | + ) -> torch.Tensor: |
99 | 249 | """ |
100 | 250 | Computes the forward pass of the equivariant convolution. |
101 | 251 |
|
@@ -137,12 +287,19 @@ def __init__(self, *args, **kwargs): |
137 | 287 | ) |
138 | 288 |
|
139 | 289 |
|
class FusedDepthwiseConv(FusedConv):
    """
    Fused equivariant convolution layer using a depthwise tensor product.

    Identical to ``FusedConv`` but with the tensor-product factory fixed to
    ``DepthwiseTensorProduct`` (the class docstring previously said
    "separable tensor product", which did not match the code).

    ref: https://arxiv.org/abs/1802.08219
    ref: https://arxiv.org/abs/2206.11990
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``FusedConv``, fixing
        ``tensor_product=DepthwiseTensorProduct``."""
        super().__init__(
            *args,
            **kwargs,
            tensor_product=DepthwiseTensorProduct,
        )
147 | 304 |
|
148 | 305 |
|
|
0 commit comments