@cathalobrien
Hello,

This PR reduces the number of CUDA stream syncs in bipartite_subgraph() (torch_geometric/utils/_subgraph.py) from 4 to 1 when running on GPU.
It does this by replacing nonzero() with nonzero_static(), and by replacing an indexing operation in index_to_mask() (torch_geometric/utils/mask.py) with scatter().

One sync remains. I think it comes from the boolean-mask indexing below, which calls nonzero():

# torch_geometric/utils/_subgraph.py, bipartite_subgraph()
    edge_index = edge_index[:, edge_mask]
    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None
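
To illustrate why this indexing is a plausible source of the remaining sync: selecting with a boolean mask yields an output whose size depends on how many entries are True. The sketch below, using made-up example tensors, shows that the masked selection gives the same result as first resolving the mask to positions and then gathering columns. It does not claim to reproduce what PyTorch does internally.

```python
import torch

# Made-up example data: 4 edges, 2 of them selected by the mask.
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
edge_mask = torch.tensor([True, False, True, False])

# Boolean-mask indexing: the number of output columns depends on the mask values.
masked = edge_index[:, edge_mask]

# The same selection written explicitly: mask -> positions -> gather columns.
positions = edge_mask.nonzero().squeeze(1)
explicit = edge_index.index_select(1, positions)
```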

The unit tests test/utils/test_mask.py and test/utils/test_subgraph.py pass.

See the PyTorch Perfetto trace screenshots below, taken before and after the change. The reduction in syncs is accompanied by a speedup in my use case.
Screenshot 2025-11-19 at 10 14 56
Screenshot 2025-11-19 at 11 30 48
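
Traces like the ones in the screenshots can be captured with `torch.profiler`. The following is a minimal sketch, not the exact procedure used for this PR: `trace_call` and the default output path are hypothetical, and CUDA activity is only recorded when a GPU is available.

```python
import torch
from torch.profiler import ProfilerActivity, profile


def trace_call(fn, path: str = "trace.json") -> str:
    """Hypothetical helper: profile one call of `fn` and write a Chrome-format trace."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        # Record GPU kernels and CUDA runtime calls as well.
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        fn()
    # The exported JSON file can be opened in the Perfetto UI.
    prof.export_chrome_trace(path)
    return path
```

For example, `trace_call(lambda: bipartite_subgraph(subset, edge_index))` with suitable inputs would write a trace in which synchronization calls can be looked up by name.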
