
Pooling on a heterogeneous graph can be achieved by pooling each node type separately and combining the resulting graph-level outputs afterwards, e.g.:

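import torch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool

loader = DataLoader(dataset, ...)  # dataset of HeteroData graphs; other arguments elided
for data in loader:
    # Mean-pool each node type separately into one graph-level tensor per type
    x_dict = {key: global_mean_pool(data[key].x, data[key].batch) for key in data.node_types}
    # torch.stack needs a list/tuple (not dict_values); summing needs equal feature sizes
    out = torch.stack(list(x_dict.values()), dim=0).sum(dim=0)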
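To make explicit what the snippet computes, here is a minimal sketch in plain PyTorch, assuming node features and batch vectors are already held in dictionaries keyed by node type. `mean_pool` and `pool_hetero` are hypothetical helpers written for illustration, not PyG APIs; `mean_pool` mimics the per-graph averaging that `global_mean_pool` performs. Passing `num_graphs` explicitly keeps every per-type output the same height: without it, the number of rows is inferred from the largest graph index, so a node type that is absent from the last graph of a batch would produce a shorter tensor and `torch.stack` would fail. (PyG's `global_mean_pool` accepts an optional `size` argument for the same purpose.)

```python
import torch


def mean_pool(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Average the rows of `x` that share the same graph index in `batch`.

    Hypothetical helper. Returns a [num_graphs, num_features] tensor; graphs
    that contain no nodes of this type get a zero vector.
    """
    out = torch.zeros(num_graphs, x.size(-1), dtype=x.dtype)
    out.index_add_(0, batch, x)  # per-graph sum of node features
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return out / counts.unsqueeze(-1).to(x.dtype)  # sum -> mean


def pool_hetero(x_dict, batch_dict, num_graphs, combine="sum"):
    """Hypothetical helper: pool each node type, then combine the results."""
    pooled = [mean_pool(x_dict[k], batch_dict[k], num_graphs) for k in sorted(x_dict)]
    if combine == "sum":
        # Summing requires all node types to share the same feature size.
        return torch.stack(pooled, dim=0).sum(dim=0)
    # Concatenation also works when feature sizes differ between node types.
    return torch.cat(pooled, dim=-1)


# Two graphs in one mini-batch: "paper" nodes belong to graphs 0, 0, 1 and
# "author" nodes to graphs 0, 1.
x_dict = {
    "paper": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    "author": torch.tensor([[10.0, 10.0], [20.0, 20.0]]),
}
batch_dict = {"paper": torch.tensor([0, 0, 1]), "author": torch.tensor([0, 1])}

print(pool_hetero(x_dict, batch_dict, num_graphs=2))  # -> [[12, 13], [25, 26]]
```

Summing keeps the output size equal to the shared feature size, whereas concatenating (`combine="cat"` in this sketch) grows it with the number of node types but does not require the types to have matching dimensions.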

@minsikseo-cdl
@rusty1s
@minsikseo-cdl
Answer selected by minsikseo-cdl
Category: Q&A
2 participants