Get subset of graph nodes within batch #8235
Unanswered
manavkulshrestha asked this question in Q&A
Replies: 1 comment
-
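The easiest way to achieve this is to add a boolean valid_mask attribute to each Data object that marks the nodes you want to score, and then index with it when computing the loss:
loss = loss_fn(out[batch.valid_mask], batch.y[batch.valid_mask])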
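As a rough illustration of the mask idea, here is a minimal sketch in plain PyTorch. It mimics what batching does to node-level attributes (concatenation along the first dimension) rather than using a real DataBatch. The helper make_valid_mask, the toy graph sizes, and the choice of mse_loss are all assumptions made for the example. Note that in the question y holds only the labeled nodes (y=[5] for 15 nodes), so the sketch compares against y directly; indexing y with the mask, as in the reply, would only apply if y had one entry per node.

```python
import torch
import torch.nn.functional as F


def make_valid_mask(num_nodes: int, num_labeled: int) -> torch.Tensor:
    """Hypothetical helper: True for the first `num_labeled` nodes of one graph."""
    mask = torch.zeros(num_nodes, dtype=torch.bool)
    mask[:num_labeled] = True
    return mask


# Two toy graphs: 15 nodes with 5 labeled, and 10 nodes with 3 labeled.
masks = [make_valid_mask(15, 5), make_valid_mask(10, 3)]
labels = [torch.randn(5), torch.randn(3)]

# Stand-in for batching: node-level tensors are concatenated along dim 0.
valid_mask = torch.cat(masks)  # shape [25]
y = torch.cat(labels)          # shape [8], labeled nodes only

# Stand-in for the model output: one prediction per node in the batch.
out = torch.randn(25)

# Boolean indexing keeps node order, so the selection lines up with y.
selected = out[valid_mask]     # shape [8]
loss = F.mse_loss(selected, y)
```

With a real dataset, the mask would be attached to each Data object before it reaches the dataloader, so that it is batched together with x and y.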
-
Hello, I have a question about how to do something with batching. My dataloader yields the following Data objects (shown here with batch size 1):
DataBatch(x=[15, 4], edge_index=[2, 120], y=[5], batch=[15], ptr=[2])
Once my model gives me its output, I only want to compute the loss on the node predictions for the first len(y) nodes in each graph, since those nodes are different and my model output is only defined for them. That is, something like taking the first len(y) rows of the output and comparing them against y. But this, of course, doesn't work with batching, since multiple graph instances are concatenated and the out values I care about are now interleaved with the rest. Is there some way to achieve what I need? I cannot think of a good way to do this.
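To make the interleaving concrete, the following sketch builds the global indices of the first k nodes of every graph from the batch's ptr vector, which stores the offset at which each graph's nodes start. The function first_k_node_index and the num_labeled tensor (the number of labeled nodes per graph) are hypothetical names introduced for this example; a per-graph count like this would have to be stored on each Data object or otherwise known.

```python
import torch


def first_k_node_index(ptr: torch.Tensor, num_labeled: torch.Tensor) -> torch.Tensor:
    """Global indices of the first num_labeled[i] nodes of graph i.

    `ptr` has one entry per graph plus a final end offset, so graph i
    occupies rows ptr[i] .. ptr[i + 1] - 1 of the batched node tensors.
    """
    starts = ptr[:-1]
    pieces = [
        torch.arange(int(start), int(start) + int(k))
        for start, k in zip(starts, num_labeled)
    ]
    return torch.cat(pieces)


# Two graphs of 15 and 10 nodes, with 5 and 3 labeled nodes respectively.
ptr = torch.tensor([0, 15, 25])
num_labeled = torch.tensor([5, 3])

idx = first_k_node_index(ptr, num_labeled)

# A naive slice out[:8] would take rows 0..7, all from the first graph;
# the per-graph indices skip ahead to where the second graph begins.
out = torch.randn(25)
selected = out[idx]
```

This yields the same selection as a per-node boolean mask; the mask is usually simpler because it is batched automatically, while the index version avoids storing an extra per-node attribute.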