Get original indices of selected nodes in Random Sampler #4532
itsdaniele asked this question in Q&A.
Hi there! Is there a simple way to get the original indices of the nodes selected for a given batch by a RandomNodeSampler? I need them to map between the full graph and the sampled subgraph. Thanks a lot in advance.
Answered by rusty1s on Apr 26, 2022:
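```python
import torch
from torch_geometric.loader import RandomNodeSampler
# Store each node's full-graph index; the sampler slices it like any node feature.
data.n_id = torch.arange(data.num_nodes)
loader = RandomNodeSampler(data, ...)  # sampler arguments, e.g. num_parts
batch = next(iter(loader))
print(batch.n_id)  # full-graph indices of the nodes in this batch
```

This should do the trick. It is also documented here.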
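To make the mechanism behind the answer concrete, here is a simplified pure-Python sketch (no PyTorch or PyG required). It assumes a sampler that randomly assigns every node to one of `num_parts` parts and slices all node-level attributes to each part, loosely modeled on how RandomNodeSampler behaves; it is not PyG's implementation. The helpers `random_node_partition` and `induced_subgraph` are hypothetical names used only for illustration. Because `n_id` is just another node-level attribute, it is sliced together with `x`, so every row of a batch can be mapped back to the full graph.

```python
import random


def random_node_partition(num_nodes, num_parts, seed=0):
    """Hypothetical helper: randomly assign each node to one of num_parts parts."""
    rng = random.Random(seed)
    assignment = [rng.randrange(num_parts) for _ in range(num_nodes)]
    return [[v for v in range(num_nodes) if assignment[v] == p]
            for p in range(num_parts)]


def induced_subgraph(node_attrs, edges, node_idx):
    """Hypothetical helper: slice every node-level attribute to node_idx and
    keep only edges with both endpoints inside it, relabeled to local indices."""
    global_to_local = {g: l for l, g in enumerate(node_idx)}
    sub_attrs = {key: [values[g] for g in node_idx]
                 for key, values in node_attrs.items()}
    sub_edges = [(global_to_local[u], global_to_local[v])
                 for u, v in edges
                 if u in global_to_local and v in global_to_local]
    return sub_attrs, sub_edges


# Toy full graph: 6 nodes in a ring, one feature "x", plus the n_id trick.
num_nodes = 6
node_attrs = {
    "x": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
    "n_id": list(range(num_nodes)),  # analogue of torch.arange(data.num_nodes)
}
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]

# Write per-node results computed on each subgraph back into full-graph order.
full_out = [None] * num_nodes
for part in random_node_partition(num_nodes, num_parts=2, seed=42):
    batch_attrs, batch_edges = induced_subgraph(node_attrs, edges, part)
    out = [x * 2 for x in batch_attrs["x"]]  # stand-in for a model's output
    # n_id maps each local row back to its row in the full graph.
    for local, global_id in enumerate(batch_attrs["n_id"]):
        full_out[global_id] = out[local]

print(full_out)  # → [20.0, 22.0, 24.0, 26.0, 28.0, 30.0]
```

The result does not depend on how the nodes were partitioned, because `n_id` restores the original order. With real PyG tensors the write-back loop would typically collapse into a single indexed assignment such as `full_out[batch.n_id] = out`.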
Answer selected by itsdaniele.