-
I encountered an error. I simplified the code below to reproduce it, and commented in the code how to bypass the error.
import torch
from torch_geometric.loader import NeighborLoader
import torch.multiprocessing as mp
from torch_geometric.data import Data
def run(rank, data) -> None:
    """Worker entry point for ``mp.spawn``: builds a NeighborLoader over
    ``data`` and pulls a single batch to trigger (or reproduce) the error.

    Args:
        rank: Worker index passed by ``mp.spawn`` (unused here).
        data: A ``torch_geometric.data.Data`` object carrying ``train_idx``.
    """
    # This line will later produce error
    # NOTE(review): `split` with the full length returns a single chunk that
    # is a *view* of `data.train_idx` — same data_ptr, same values, but not
    # an independent tensor. Presumably that view is what breaks when the
    # persistent worker process accesses it — TODO confirm.
    train_idx = data.train_idx.split(data.train_idx.size(0))[0]
    # Uncomment to fix error. Though `data_ptr()` and all tensor values are the same for tensor above.
    # train_idx = data.train_idx
    train_loader = NeighborLoader(
        data=data,
        input_nodes=train_idx,
        num_neighbors=[5, 5],
        shuffle=True,
        drop_last=True,
        batch_size=1,
        num_workers=1,               # worker subprocess is where the failure surfaces
        persistent_workers=True,
    )
    # here error will occur
    print(next(iter(train_loader)))
    return None
# Basic data initialization and process spawning
if __name__ == '__main__':
    # Deterministic setup so the reproduction behaves the same on every run.
    torch.manual_seed(0)

    # Synthetic graph dimensions: 1M nodes, 2M random edges, 10-dim features,
    # with one fifth of the nodes used as training indices.
    n_nodes = 1000000
    feat_dim = 10
    n_edges = n_nodes * 2
    n_train = n_nodes // 5

    graph = Data(
        x=torch.rand(n_nodes, feat_dim),
        edge_index=torch.randint(0, n_nodes, (2, n_edges)),
        train_idx=torch.randperm(n_train),
    )

    # Launch a single worker process running `run`; join before exiting.
    mp.spawn(run, args=(graph, ), nprocs=1, join=True)
Environment
Thank you |
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 2 replies
-
The following code works.
I remember Matthias mentioned that num_workers can not > 0 because cuda library limit. But I cannot find the source. |
Beta Was this translation helpful? Give feedback.
-
I think your example crashes because `train_idx = data.train_idx.split(data.train_idx.size(0))[0]` creates a view.
`train_idx = train_idx.clone()` fixes this for me. |
Beta Was this translation helpful? Give feedback.
I think your example crashes because `split` creates a view, and this view is corrupted since multiple processes are trying to access it. `train_idx = train_idx.clone()` fixes this for me.