Skip to content

Commit 211bc48

Browse files
committed
move scatter op out from contracter module
1 parent ca6624e commit 211bc48

File tree

5 files changed

+46
-56
lines changed

5 files changed

+46
-56
lines changed

allegro/nn/_allegro.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ScalarMLPFunction,
1414
tp_path_exists,
1515
AvgNumNeighborsNorm,
16+
scatter,
1617
)
1718

1819
from ._strided import Contracter, MakeWeightedChannels
@@ -258,18 +259,36 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
258259
projection, -1, self.num_scalar_features, self._env_weighter.weight_numel
259260
)
260261

261-
# Get normalization tensor
262-
scatter_norm = self.avg_num_neighbors_norm(data)[:num_atoms].unsqueeze(-1)
263-
264262
layer_index: int = 0
265263
for latent, tp in zip(self.latents, self.tps):
266-
# === Env Weight & TP ===
264+
# === construct env weighted tensor ===
267265
env_w_edges = self._env_weighter(tensor_basis, env_w)
268-
# scatter env_w_edges and TP with tensor_features
269-
# second input irreps is the one that is scattered
266+
267+
# scatter env_w_edges to nodes and normalize
268+
env_w_scatter = scatter(
269+
env_w_edges,
270+
edge_center,
271+
dim=0,
272+
dim_size=num_atoms,
273+
)
274+
env_w_scatter_size0 = env_w_scatter.size(0)
275+
env_w_scatter_size1 = env_w_scatter.size(1)
276+
env_w_scatter_size2 = env_w_scatter.size(2)
277+
data[AtomicDataDict.NODE_FEATURES_KEY] = env_w_scatter.view(
278+
env_w_scatter_size0,
279+
env_w_scatter_size1 * env_w_scatter_size2,
280+
)
281+
data = self.avg_num_neighbors_norm(data)
282+
283+
# === TP ===
284+
# second input irreps is node-scattered env features
270285
irin1 = tensor_features
271-
irin2 = env_w_edges
272-
tensor_features = tp(irin1, irin2, edge_center, num_atoms, scatter_norm)
286+
irin2 = data[AtomicDataDict.NODE_FEATURES_KEY].view(
287+
env_w_scatter_size0,
288+
env_w_scatter_size1,
289+
env_w_scatter_size2,
290+
)
291+
tensor_features = tp(irin1, irin2, edge_center)
273292

274293
# Extract invariants from tensor track
275294
# features has shape [z][mul][k], where scalars are first

allegro/nn/_strided/_contract.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from e3nn.o3._irreps import Irreps
55
from e3nn.o3._wigner import wigner_3j
6-
from nequip.nn import scatter, replace_submodules, model_modifier
6+
from nequip.nn import replace_submodules, model_modifier
77
from nequip.utils.dtype import torch_default_dtype
88
from typing import List, Tuple, Optional
99

@@ -187,24 +187,12 @@ def forward(
187187
x1: torch.Tensor,
188188
x2: torch.Tensor,
189189
idxs: torch.Tensor,
190-
scatter_dim_size: int,
191-
scatter_norm: torch.Tensor,
192190
) -> torch.Tensor:
193-
# scatter and index select
194-
x2_scatter = scatter(
195-
x2,
196-
idxs,
197-
dim=0,
198-
dim_size=scatter_dim_size,
199-
)
200-
# normalization
201-
x2_scatter = x2_scatter * scatter_norm
202-
203191
# === perform TP ===
204192
# convert to strided shape
205193
x1 = x1.reshape(-1, self.mul, self.base_dim1)
206-
x2_scatter = x2_scatter.reshape(-1, self.mul, self.base_dim2)
207-
return self._contract_conv(x1, x2_scatter, idxs)
194+
x2 = x2.reshape(-1, self.mul, self.base_dim2)
195+
return self._contract_conv(x1, x2, idxs)
208196

209197
def _contract_conv(
210198
self, x1: torch.Tensor, x2: torch.Tensor, idxs: torch.Tensor

allegro/nn/_strided/_cueq_contracter.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# This file is a part of the `allegro` package. Please see LICENSE and README at the root for information on using it.
22
import torch
33

4-
from nequip.nn import scatter
54
from ._contract import Contracter
65

76
import itertools
@@ -87,20 +86,12 @@ def forward(
8786
x1: torch.Tensor,
8887
x2: torch.Tensor,
8988
idxs: torch.Tensor,
90-
scatter_dim_size: int,
91-
scatter_norm: torch.Tensor,
9289
) -> torch.Tensor:
9390
# NOTE: the reason for some duplicated code is because TorchScript doesn't support super() calls
9491
# see https://github.com/pytorch/pytorch/issues/42885
9592

96-
x2_scatter = scatter(
97-
x2,
98-
idxs,
99-
dim=0,
100-
dim_size=scatter_dim_size,
101-
)
102-
103-
x2_scatter = x2_scatter * scatter_norm
93+
x1 = x1.reshape(-1, self.mul, self.base_dim1)
94+
x2 = x2.reshape(-1, self.mul, self.base_dim2)
10495

10596
if x1.is_cuda and self.num_paths >= 1:
10697
empty_dict: Dict[int, torch.Tensor] = {} # for torchscript
@@ -120,10 +111,10 @@ def forward(
120111
.view(
121112
x1.size(0), self.base_dim1 * self.mul
122113
), # (edges, irreps * mul)
123-
x2_scatter.transpose(1, 2)
114+
x2.transpose(1, 2)
124115
.contiguous()
125116
.view(
126-
scatter_dim_size, self.base_dim2 * self.mul
117+
x2.size(0), self.base_dim2 * self.mul
127118
), # (atoms, irreps * mul)
128119
],
129120
{2: idxs}, # input indices
@@ -138,4 +129,4 @@ def forward(
138129
.contiguous()
139130
)
140131
else:
141-
return self._contract_conv(x1, x2_scatter, idxs)
132+
return self._contract_conv(x1, x2, idxs)

allegro/nn/edgewise.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
5353
reduce=self.reduce,
5454
)
5555
# === scale ===
56-
factor = self.norm_module(data)[: AtomicDataDict.num_nodes(data)]
57-
out = out * (factor / sqrt(2))
56+
data[AtomicDataDict.NODE_FEATURES_KEY] = out
57+
data = self.norm_module(data)
58+
out = data[AtomicDataDict.NODE_FEATURES_KEY] / sqrt(2)
5859
# ^ factor of 2 to normalize dE/dr_i which includes both contributions from dE/dr_ij
5960
# and every other derivative against r_ji.
6061

allegro/utils/autotune_triton.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,11 @@ def create_nacl_supercell(supercell_size=10):
8686
]
8787

8888

89-
def benchmark_forward(
90-
contracter, input1, input2, scatter_idxs, num_atoms, warmup=3, n_iter=10
91-
):
89+
def benchmark_forward(contracter, input1, input2, scatter_idxs, warmup=3, n_iter=10):
9290
"""Benchmark forward pass."""
9391
# warmup
9492
for _ in range(warmup):
95-
_ = contracter(input1, input2, scatter_idxs, num_atoms)
93+
_ = contracter(input1, input2, scatter_idxs)
9694
torch.cuda.synchronize()
9795

9896
# benchmark
@@ -101,7 +99,7 @@ def benchmark_forward(
10199

102100
start_event.record()
103101
for _ in range(n_iter):
104-
_ = contracter(input1, input2, scatter_idxs, num_atoms)
102+
_ = contracter(input1, input2, scatter_idxs)
105103
end_event.record()
106104

107105
torch.cuda.synchronize()
@@ -110,9 +108,7 @@ def benchmark_forward(
110108
return total_time_ms / n_iter
111109

112110

113-
def benchmark_backward(
114-
contracter, input1, input2, scatter_idxs, num_atoms, warmup=3, n_iter=10
115-
):
111+
def benchmark_backward(contracter, input1, input2, scatter_idxs, warmup=3, n_iter=10):
116112
"""Benchmark full forward+backward pass.
117113
118114
Returns:
@@ -122,7 +118,7 @@ def benchmark_backward(
122118
for _ in range(warmup):
123119
input1.grad = None
124120
input2.grad = None
125-
out = contracter(input1, input2, scatter_idxs, num_atoms)
121+
out = contracter(input1, input2, scatter_idxs)
126122
grad_out = torch.randn_like(out)
127123
out.backward(grad_out)
128124
torch.cuda.synchronize()
@@ -135,7 +131,7 @@ def benchmark_backward(
135131
for _ in range(n_iter):
136132
input1.grad = None
137133
input2.grad = None
138-
out = contracter(input1, input2, scatter_idxs, num_atoms)
134+
out = contracter(input1, input2, scatter_idxs)
139135
grad_out = torch.randn_like(out)
140136
out.backward(grad_out)
141137
end_event.record()
@@ -188,7 +184,6 @@ def autotune(
188184
num_nodes = AtomicDataDict.num_nodes(data)
189185
num_edges = AtomicDataDict.num_edges(data)
190186
scatter_idxs = data[AtomicDataDict.EDGE_INDEX_KEY][1]
191-
num_atoms_tensor = torch.tensor([num_nodes], dtype=torch.int64, device=device)
192187

193188
print(f" num_nodes: {num_nodes}")
194189
print(f" num_edges: {num_edges}")
@@ -206,9 +201,9 @@ def autotune(
206201
irreps_in2 = model_config["irreps_in2"]
207202
mul = model_config["mul"]
208203

209-
# both inputs are edge-indexed, enable grad for backward
204+
# input1 is edge-indexed, input2 is node-indexed
210205
input1 = irreps_in1.randn(num_edges, mul, -1, dtype=dtype, device=device)
211-
input2 = irreps_in2.randn(num_edges, mul, -1, dtype=dtype, device=device)
206+
input2 = irreps_in2.randn(num_nodes, mul, -1, dtype=dtype, device=device)
212207
input1.requires_grad_(True)
213208
input2.requires_grad_(True)
214209

@@ -237,7 +232,6 @@ def autotune(
237232
input1,
238233
input2,
239234
scatter_idxs,
240-
num_atoms_tensor,
241235
warmup=5,
242236
n_iter=20,
243237
)
@@ -247,7 +241,6 @@ def autotune(
247241
input1,
248242
input2,
249243
scatter_idxs,
250-
num_atoms_tensor,
251244
warmup=5,
252245
n_iter=20,
253246
)
@@ -285,7 +278,6 @@ def autotune(
285278
input1,
286279
input2,
287280
scatter_idxs,
288-
num_atoms_tensor,
289281
warmup=5,
290282
n_iter=20,
291283
)
@@ -295,7 +287,6 @@ def autotune(
295287
input1,
296288
input2,
297289
scatter_idxs,
298-
num_atoms_tensor,
299290
warmup=5,
300291
n_iter=20,
301292
)

0 commit comments

Comments
 (0)