diff --git a/CHANGELOG.md b/CHANGELOG.md index 709cced4b680..50956f5b9432 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Fixed +- Fixed `to_dense_adj` function by ordering the node idices by batch ([#10535](https://github.com/pyg-team/pytorch_geometric/pull/10535)) + ### Security ## [2.7.0] - 2025-10-14 diff --git a/test/utils/test_to_dense_adj.py b/test/utils/test_to_dense_adj.py index a6c8a94981f3..77a298150e5a 100644 --- a/test/utils/test_to_dense_adj.py +++ b/test/utils/test_to_dense_adj.py @@ -106,3 +106,28 @@ def test_to_dense_adj_with_duplicate_entries(): [0.0, 0.0, 13.0], [8.0, 0.0, 0.0], ] + + +def test_to_dense_adj_with_unordered_batch(): + edge_index = torch.tensor([ + [0, 1, 2, 3], + [3, 2, 1, 0], + ]) + batch = torch.tensor([0, 1, 1, 0]) + + adj = to_dense_adj(edge_index, batch) + assert adj.size() == (2, 2, 2) + assert adj[0].tolist() == [[0.0, 1.0], [1.0, 0.0]] + assert adj[1].tolist() == [[0.0, 1.0], [1.0, 0.0]] + + edge_index = torch.tensor([ + [0, 1, 2, 3], + [3, 2, 1, 0], + ]) + batch = torch.tensor([0, 1, 1, 0]) + edge_attr = torch.tensor([1.0, 3.0, 4.0, 2.0]) + + adj = to_dense_adj(edge_index, batch, edge_attr) + assert adj.size() == (2, 2, 2) + assert adj[0].tolist() == [[0.0, 1.0], [2.0, 0.0]] + assert adj[1].tolist() == [[0.0, 3.0], [4.0, 0.0]] diff --git a/torch_geometric/utils/_to_dense_adj.py b/torch_geometric/utils/_to_dense_adj.py index 5ca5758956e5..0ac174e69e72 100644 --- a/torch_geometric/utils/_to_dense_adj.py +++ b/torch_geometric/utils/_to_dense_adj.py @@ -67,6 +67,12 @@ def to_dense_adj( if batch_size is None: batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1 + perm = batch.argsort() + batch = batch[perm] + new_index_map = torch.empty_like(perm) + new_index_map[perm] = torch.arange(perm.size(0)) + edge_index = new_index_map[edge_index] + one = batch.new_ones(batch.size(0)) num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='sum') cum_nodes = cumsum(num_nodes)