Skip to content

Commit 25185e8

Browse files
authored
Fix aggregate nodes when last batch in a graph has no edges/nodes (#119)
1 parent d09b8a2 commit 25185e8

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

orb_models/forcefield/segment_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ def aggregate_nodes(
4444
torch.use_deterministic_algorithms(True)
4545
segments = torch.arange(count, device=device).repeat_interleave(n_node)
4646
if reduction == "sum":
47-
return scatter_sum(tensor, segments, dim=0)
47+
return scatter_sum(tensor, segments, dim=0, dim_size=count)
4848
elif reduction == "mean":
49-
return scatter_mean(tensor, segments, dim=0)
49+
return scatter_mean(tensor, segments, dim=0, dim_size=count)
5050
elif reduction == "max":
51-
return segment_max(tensor, segments, count)
51+
return segment_max(tensor, segments, num_segments=count)
5252
else:
5353
raise ValueError("Invalid reduction argument. Use sum, mean or max.")
5454

tests/atomgraphs/test_segment_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,16 @@ def test_aggregate_nodes(reduction, dtype):
5858
assert torch.allclose(res[2, :], reduce_fn(tensor[8:, :]))
5959

6060

61+
@pytest.mark.parametrize("reduction", ["sum", "mean", "max"])
62+
def test_aggregate_nodes_with_zero_size_last_graph(reduction):
63+
tensor = torch.randn(10, 10)
64+
n_node = torch.tensor([8, 2, 0, 0], dtype=torch.int32)
65+
66+
res = segment_ops.aggregate_nodes(tensor, n_node=n_node, reduction=reduction)
67+
68+
assert res.shape == (4, 10)
69+
70+
6171
@pytest.mark.parametrize(
6272
"dtype", [torch.float32, torch.float64, torch.int32, torch.int64]
6373
)

0 commit comments

Comments
 (0)