Oh, that is interesting. batch indeed needs to be sorted for TopKPooling to work correctly. However, you are right that to_homogeneous() does not guarantee this for an already batched HeteroData object: nodes are concatenated node type by node type, so the resulting batch vector can interleave graphs. I think you have two options:

  1. Call to_homogeneous() on each graph individually and use DataLoader to batch them afterwards, e.g.:
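from torch_geometric.loader import DataLoader


class MyTransform:
    def __call__(self, data):
        # Convert a single HeteroData graph into a homogeneous Data object,
        # keeping only the node feature 'x' shared by all node types.
        data = data.to_homogeneous(node_attrs=['x'], add_node_type=False,
                                   add_edge_type=False)
        return data


dataset = MyHeteroDataset(..., transform=MyTransform())
# DataLoader batches the already-homogeneous graphs one after another,
# so the resulting batch vector is sorted.
loader = DataLoader(dataset, batch_size=32)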
  2. Re-order your data object as part of your model:
perm = data.batch.argsort()

data.batch = data.
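The re-ordering in the second option involves more than sorting batch: every node-level attribute has to be permuted with the same perm, and edge_index has to be remapped through the inverse permutation because edges still refer to the old node positions. Below is a minimal plain-Python sketch of that idea, not the original (truncated) snippet. It assumes lists in place of tensors, and the helper name sort_nodes_by_batch and the toy "paper"/"author" graph are hypothetical. The comments note the corresponding PyTorch operations.

```python
from typing import List, Sequence, Tuple

EdgeIndex = Tuple[List[int], List[int]]


def sort_nodes_by_batch(
    x: Sequence, edge_index: EdgeIndex, batch: Sequence[int]
) -> Tuple[list, EdgeIndex, List[int], List[int]]:
    """Hypothetical helper: reorder nodes so that `batch` is non-decreasing.

    Returns the permuted node features, the remapped edge_index,
    the sorted batch vector and the permutation itself.
    """
    n = len(batch)
    # perm[new_pos] = old_pos. Python's sort is stable, so nodes of the
    # same graph keep their relative order (torch: batch.argsort()).
    perm = sorted(range(n), key=lambda i: batch[i])
    # Inverse permutation: inv[old_pos] = new_pos
    # (torch: inv = torch.empty_like(perm); inv[perm] = torch.arange(n)).
    inv = [0] * n
    for new_pos, old_pos in enumerate(perm):
        inv[old_pos] = new_pos
    # Node-level attributes are gathered with perm (torch: x[perm]).
    x_sorted = [x[i] for i in perm]
    batch_sorted = [batch[i] for i in perm]
    # Edges store old node indices, so map both endpoints through inv
    # (torch: edge_index = inv[edge_index]).
    src, dst = edge_index
    edge_index_sorted = ([inv[s] for s in src], [inv[d] for d in dst])
    return x_sorted, edge_index_sorted, batch_sorted, perm


# Hypothetical toy input: two graphs with node types "paper" and "author",
# concatenated type by type as in a batched heterogeneous graph.
x = ["g0_paper0", "g0_paper1", "g1_paper0",
     "g0_author0", "g1_author0", "g1_author1"]
batch = [0, 0, 1, 0, 1, 1]  # graphs are interleaved, so not sorted
edge_index = ([0, 1, 2, 2], [3, 3, 4, 5])

x_s, ei_s, batch_s, perm = sort_nodes_by_batch(x, edge_index, batch)
print(batch_s)  # [0, 0, 0, 1, 1, 1]
print(x_s)
print(ei_s)     # ([0, 1, 3, 3], [2, 2, 4, 5])
```

Any other node-level attribute you kept (for example node_type, if add_node_type=True) would need the same perm indexing. Edge-level attributes stay as they are, since only the node indices inside edge_index change.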

Replies: 1 comment 4 replies

@minsikseo-cdl
@rusty1s
@rusty1s
@minsikseo-cdl
Answer selected by minsikseo-cdl
Category
Q&A
2 participants