Out of bounds error in TopKPooling with batched graph #3753
minsikseo-cdl asked this question in Q&A
Answered by rusty1s on Dec 23, 2021
Oh, that is interesting.
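Option 1: call `to_homogeneous()` on each graph individually in a transform and let `DataLoader` batch the results:

```python
from torch_geometric.loader import DataLoader

class MyTransform:
    def __call__(self, data):
        # Convert each single HeteroData graph before it gets batched.
        return data.to_homogeneous(node_attrs=['x'], add_node_type=False,
                                   add_edge_type=False)

dataset = MyHeteroDataset(..., transform=MyTransform())
loader = DataLoader(dataset, batch_size=32)
```

Option 2: sort the already batched `data` object by `batch` as part of your model:

```python
perm = data.batch.argsort()  # perm[i] = old index of the node at new position i
data.batch = data.batch[perm]
data.x = data.x[perm]
# Edges store old node indices, so map them through the inverse permutation.
data.edge_index = perm.argsort()[data.edge_index]
```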
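To make the reordering step easier to follow, here is a minimal plain-Python sketch that uses lists in place of tensors. The toy values for `batch`, `x` and `edge_index` are assumptions chosen for illustration and do not come from the original question. It shows that node features and `batch` are indexed by the sorting permutation, while edge indices have to go through its inverse so that every edge still connects the same two nodes.

```python
# Plain-Python sketch (lists instead of tensors) of sorting a batched graph by `batch`.
batch = [1, 0, 1, 0]              # graph id per node, not in order
x = ["a", "b", "c", "d"]          # one (toy) feature per node
edge_index = [[0, 1], [2, 3]]     # edges 0 -> 2 (graph 1) and 1 -> 3 (graph 0)

# Stable argsort: perm[i] is the old index of the node that moves to position i.
perm = sorted(range(len(batch)), key=lambda i: batch[i])

# Inverse permutation: inv[old] is the new position of node `old`.
inv = [0] * len(perm)
for new, old in enumerate(perm):
    inv[old] = new

new_batch = [batch[i] for i in perm]
new_x = [x[i] for i in perm]
new_edge_index = [[inv[n] for n in row] for row in edge_index]

# Each edge should still connect the same pair of features as before.
old_pairs = [(x[s], x[d]) for s, d in zip(*edge_index)]
new_pairs = [(new_x[s], new_x[d]) for s, d in zip(*new_edge_index)]
```

In this toy case `perm` and `inv` differ, so indexing the edges with `perm` directly would connect the wrong nodes.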
Answer selected by minsikseo-cdl
`batch` indeed needs to be in order for `TopKPooling` to work correctly. However, you are right that `to_homogeneous()` does not guarantee that for already batched `HeteroData` objects. I think you have two options:

1. Call `to_homogeneous()` on each graph individually and use `DataLoader` to batch them afterwards, e.g., via a transform such as `MyTransform`.
2. Sort the `data` object by `batch` as part of your model.
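To find out whether a given `batch` vector is already in order before it reaches the pooling layer, a quick check along these lines may help. This is only a sketch: `is_sorted` is a hypothetical helper name, and plain Python lists stand in for the `batch` tensor.

```python
def is_sorted(batch):
    """Return True if the graph ids in `batch` are non-decreasing."""
    return all(a <= b for a, b in zip(batch, batch[1:]))


ordered = is_sorted([0, 0, 1, 1, 2])    # nodes grouped graph by graph
unordered = is_sorted([0, 1, 0, 1, 2])  # nodes of graphs 0 and 1 interleaved
```

If the check fails for a batch coming out of `to_homogeneous()`, either of the two options above should restore the expected ordering.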