From 0ad26903707feb9a386b0925a7b05acfd3378950 Mon Sep 17 00:00:00 2001 From: Cathal OBrien Date: Wed, 19 Nov 2025 12:38:27 +0000 Subject: [PATCH 1/2] reduce syncs --- torch_geometric/utils/_subgraph.py | 7 ++++--- torch_geometric/utils/mask.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/torch_geometric/utils/_subgraph.py b/torch_geometric/utils/_subgraph.py index 7540cc9f2c26..a196c168337d 100644 --- a/torch_geometric/utils/_subgraph.py +++ b/torch_geometric/utils/_subgraph.py @@ -138,7 +138,7 @@ def subgraph( else: num_nodes = subset.size(0) node_mask = subset - subset = node_mask.nonzero().view(-1) + subset = node_mask.nonzero_static(size=num_nodes).view(-1) edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] @@ -219,7 +219,7 @@ def bipartite_subgraph( else: src_size = src_subset.size(0) src_node_mask = src_subset - src_subset = src_subset.nonzero().view(-1) + src_subset = src_subset.nonzero_static(size=src_size).view(-1) if dst_subset.dtype != torch.bool: dst_size = int(edge_index[1].max()) + 1 if size is None else size[1] @@ -227,7 +227,8 @@ def bipartite_subgraph( else: dst_size = dst_subset.size(0) dst_node_mask = dst_subset - dst_subset = dst_subset.nonzero().view(-1) + dst_subset = dst_subset.nonzero_static(size=dst_size).view(-1) + edge_mask = src_node_mask[edge_index[0]] & dst_node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] diff --git a/torch_geometric/utils/mask.py b/torch_geometric/utils/mask.py index 378e98d3d946..2895058660f2 100644 --- a/torch_geometric/utils/mask.py +++ b/torch_geometric/utils/mask.py @@ -57,7 +57,7 @@ def index_to_mask(index: Tensor, size: Optional[int] = None) -> Tensor: index = index.view(-1) size = int(index.max()) + 1 if size is None else size mask = index.new_zeros(size, dtype=torch.bool) - mask[index] = True + mask.scatter_(0, index, True) return mask From ea45d16b173aa9cabbdbc2d704df9ec00713b622 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Nov 2025 12:54:33 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/utils/_subgraph.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_geometric/utils/_subgraph.py b/torch_geometric/utils/_subgraph.py index a196c168337d..7d2a69ec4e3a 100644 --- a/torch_geometric/utils/_subgraph.py +++ b/torch_geometric/utils/_subgraph.py @@ -229,7 +229,6 @@ def bipartite_subgraph( dst_node_mask = dst_subset dst_subset = dst_subset.nonzero_static(size=dst_size).view(-1) - edge_mask = src_node_mask[edge_index[0]] & dst_node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] edge_attr = edge_attr[edge_mask] if edge_attr is not None else None