Skip to content

Commit 6d05192

Browse files
committed
propagate changes to tests and tutorial config
1 parent 211bc48 commit 6d05192

File tree

3 files changed

+13
-25
lines changed

3 files changed

+13
-25
lines changed

configs/tutorial.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
run: [train, test]
33

44
cutoff_radius: 5.0
5-
model_type_names: [C, H, O, Cu]
5+
model_type_names: [C, H, O]
66
chemical_species: ${model_type_names}
77

88
data:
@@ -139,7 +139,7 @@ training_module:
139139

140140
# === misc hyperparameters ===
141141
# average number of neighbors for edge sum normalization
142-
avg_num_neighbors: ${training_data_stats:num_neighbors_mean}
142+
avg_num_neighbors: ${training_data_stats:per_type_num_neighbors_mean}
143143

144144
# per-type per-atom scales and shifts
145145
per_type_energy_shifts: ${training_data_stats:per_atom_energy_mean}

tests/nn/test_contract_basic.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,31 +57,23 @@ def test_contract_jit(
5757
c_opt_mod.weights.copy_(c_base.weights)
5858
c_opt_mod = torch.jit.script(c_opt_mod)
5959

60-
def c_opt(x, y, idx, dim, w=None):
61-
args = (x, y, idx, dim, w)
62-
if w is None:
63-
args = args[:-1]
64-
return c_opt_mod(*args)
60+
def c_opt(x, y, idx):
61+
return c_opt_mod(x, y, idx)
6562

66-
batchdim = 17
67-
scatter_dim = torch.tensor([batchdim], dtype=torch.long, device=device)
68-
scatter_idxs = torch.arange(batchdim, device=device)
63+
num_edges = 17
64+
num_nodes = 11
65+
scatter_idxs = torch.randint(0, num_nodes, (num_edges,), device=device)
6966
args_in = (
70-
irreps_in1.randn(batchdim, mul, -1, device=device),
71-
irreps_in2.randn(batchdim, mul, -1, device=device),
67+
irreps_in1.randn(num_edges, mul, -1, device=device),
68+
irreps_in2.randn(num_nodes, mul, -1, device=device),
7269
scatter_idxs,
73-
scatter_dim,
74-
torch.randn(
75-
tuple(batchdim if e == -1 else e for e in c_base.weights.shape)
76-
),
7770
)
78-
args_in = args_in[:-1]
7971

8072
for c in (c_base, c_opt):
8173
assert_equivariant(
8274
c,
8375
args_in=args_in,
84-
irreps_in=[irreps_in1, irreps_in2, None, None],
76+
irreps_in=[irreps_in1, irreps_in2, None],
8577
irreps_out=irreps_out,
8678
# e3nn uses 1e-3, 1e-9
8779
tolerance={torch.float32: 1e-3, torch.float64: 1e-8}[
@@ -164,7 +156,6 @@ def test_like_tp(
164156
# make input data
165157
batchdim = 1
166158
scatter_idxs = torch.arange(batchdim, device=device)
167-
scatter_dim = torch.tensor([batchdim], dtype=torch.long, device=device)
168159
tensor1 = torch.randn(batchdim, mul, c.base_dim1, device=device)
169160
tensor2 = torch.randn(batchdim, mul, c.base_dim2, device=device)
170161

@@ -193,9 +184,7 @@ def test_like_tp(
193184
weights_tp = weights_tp.T
194185
# else weights are just (u,)
195186
weights_tp = weights_tp.reshape(-1)
196-
c_out = _strided_to_cat(
197-
irreps_out, mul, c(tensor1, tensor2, scatter_idxs, scatter_dim)
198-
)
187+
c_out = _strided_to_cat(irreps_out, mul, c(tensor1, tensor2, scatter_idxs))
199188
tp_out = tp(
200189
_strided_to_cat(irreps_in1, mul, tensor1),
201190
_strided_to_cat(irreps_in2, mul, tensor2),

tests/nn/test_contract_kernels.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,15 @@ def test_contract_kernel(
9797
)
9898
args_in = (
9999
irreps_in1.randn(num_edges, mul, -1, device=device),
100-
irreps_in2.randn(num_edges, mul, -1, device=device),
100+
irreps_in2.randn(num_atoms, mul, -1, device=device),
101101
scatter_idxs,
102-
torch.tensor([num_atoms], dtype=torch.int64, device=device),
103102
)
104103

105104
for c in (c_base, c_kernel):
106105
assert_equivariant(
107106
c,
108107
args_in=args_in,
109-
irreps_in=[irreps_in1, irreps_in2, None, None],
108+
irreps_in=[irreps_in1, irreps_in2, None],
110109
irreps_out=irreps_out,
111110
# e3nn uses 1e-3, 1e-9
112111
tolerance={torch.float32: 1e-3, torch.float64: 1e-8}[

0 commit comments

Comments
 (0)