I don't know of any heterogeneous graph classification benchmark datasets, which is why we don't provide an example yet (please ping me if you have any pointers).

I haven't run the code since I'm currently on my phone (awesome to see you are already using FakeHeteroDataset), but I think the issue is in the way you do "global pooling". IMO, there are two options here:

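  1. Pool each node type separately, then aggregate the resulting per-type graph vectors once more, e.g. by summing them:
    import torch
    from torch_geometric.nn import global_mean_pool
    pooled = []
    for node_type in self.metadata[0]:
        # One [num_graphs, hidden_channels] tensor per node type.
        pooled.append(global_mean_pool(out[node_type], batch[node_type]))
    # [num_node_types, num_graphs, hidden_channels] -> sum over node types.
    out = torch.stack(pooled, dim=0).sum(dim=0)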
    This results in an output shape of [num_graphs, hidden_channels].
  2. You concatenate them in the feature…
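To make the shape difference between the two readouts concrete, here is a framework-free sketch in plain Python. `mean_pool`, `pool_per_type`, `readout_sum` and `readout_concat` are hypothetical helpers written only for illustration (they are not PyG API; in a real model you would use `global_mean_pool` as above), the node types `author`/`paper` are made up, and option 2 is assumed here to mean concatenation along the feature dimension.

```python
from typing import Dict, List

Matrix = List[List[float]]


def mean_pool(x: Matrix, batch: List[int], num_graphs: int, dim: int) -> Matrix:
    """Average the rows of x per graph id in batch (mimics global_mean_pool)."""
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, g in zip(x, batch):
        counts[g] += 1
        for j, v in enumerate(row):
            sums[g][j] += v
    # A graph with no nodes of this type gets an all-zero vector.
    return [[s / c if c else 0.0 for s in srow] for srow, c in zip(sums, counts)]


def pool_per_type(x_dict: Dict[str, Matrix], batch_dict: Dict[str, List[int]],
                  node_types: List[str], num_graphs: int, dim: int) -> List[Matrix]:
    """One [num_graphs, dim] matrix per node type, in node_types order."""
    return [mean_pool(x_dict[t], batch_dict[t], num_graphs, dim) for t in node_types]


def readout_sum(pooled: List[Matrix]) -> Matrix:
    """Option 1: sum over node types -> [num_graphs, dim]."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*pooled)]


def readout_concat(pooled: List[Matrix]) -> Matrix:
    """Assumed option 2: concatenate along features -> [num_graphs, num_types * dim]."""
    return [[v for row in rows for v in row] for rows in zip(*pooled)]


# Toy batch of 2 graphs with two hypothetical node types and dim = 2.
node_types = ["author", "paper"]
x_dict = {"author": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
          "paper": [[10.0, 0.0], [20.0, 0.0]]}
batch_dict = {"author": [0, 0, 1], "paper": [0, 1]}

pooled = pool_per_type(x_dict, batch_dict, node_types, num_graphs=2, dim=2)
print(readout_sum(pooled))     # [[12.0, 3.0], [25.0, 6.0]]
print(readout_concat(pooled))  # [[2.0, 3.0, 10.0, 0.0], [5.0, 6.0, 20.0, 0.0]]
```

The summed readout keeps the classifier input at `hidden_channels` no matter how many node types there are, while the concatenated one grows with the number of node types and relies on iterating them in a fixed order (e.g. `self.metadata[0]`).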

Answer selected by carlosSorcero
@rusty1s
Category
Q&A
3 participants