Consider the following dataset:

```python
import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader

datalist = []
for _ in range(100):
    size = np.random.randint(10)
    datalist.append(Data(x=torch.randn(size, 1)))

dataset = InMemoryDataset()
dataset.data, dataset.slices = dataset.collate(datalist)
```

If I batch the dataset with

```python
loader = DataLoader(dataset, batch_size=5, shuffle=True)
for batch in loader:
    break
```

is there any way for me to recover the original dataset index of each graph in the batch? Thanks a lot!
Replies: 1 comment 1 reply
AFAIK, the only way to do it is to create such an `index` tensor. PyTorch does not seem to provide such a functionality otherwise.