diff --git a/torch_geometric/utils/_subgraph.py b/torch_geometric/utils/_subgraph.py index 7540cc9f2c26..7d2a69ec4e3a 100644 --- a/torch_geometric/utils/_subgraph.py +++ b/torch_geometric/utils/_subgraph.py @@ -138,7 +138,7 @@ def subgraph( else: num_nodes = subset.size(0) node_mask = subset - subset = node_mask.nonzero().view(-1) + subset = node_mask.nonzero_static(size=num_nodes).view(-1) edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] @@ -219,7 +219,7 @@ def bipartite_subgraph( else: src_size = src_subset.size(0) src_node_mask = src_subset - src_subset = src_subset.nonzero().view(-1) + src_subset = src_subset.nonzero_static(size=src_size).view(-1) if dst_subset.dtype != torch.bool: dst_size = int(edge_index[1].max()) + 1 if size is None else size[1] @@ -227,7 +227,7 @@ def bipartite_subgraph( else: dst_size = dst_subset.size(0) dst_node_mask = dst_subset - dst_subset = dst_subset.nonzero().view(-1) + dst_subset = dst_subset.nonzero_static(size=dst_size).view(-1) edge_mask = src_node_mask[edge_index[0]] & dst_node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] diff --git a/torch_geometric/utils/mask.py b/torch_geometric/utils/mask.py index 378e98d3d946..2895058660f2 100644 --- a/torch_geometric/utils/mask.py +++ b/torch_geometric/utils/mask.py @@ -57,7 +57,7 @@ def index_to_mask(index: Tensor, size: Optional[int] = None) -> Tensor: index = index.view(-1) size = int(index.max()) + 1 if size is None else size mask = index.new_zeros(size, dtype=torch.bool) - mask[index] = True + mask.scatter_(0, index, True) return mask