
Is there any reason you want to maintain the empty node types at all? Is this for batching reasons only?

Your output looks correct to me. You now only need to aggregate the information from node types that are present:

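```python
from torch_geometric.nn import global_mean_pool

out = 0
for node_type in x_dict.keys():
    if x_dict[node_type].numel() == 0:  # node type absent from this batch: skip it
        continue
    # Pool each node type per graph and sum the results over the present types.
    out += global_mean_pool(x_dict[node_type], batch_dict[node_type])
```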
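One caveat worth checking against your data: unless `size` is passed, `global_mean_pool` infers the number of graphs as `batch.max() + 1`. A node type that is present in the batch but missing from its trailing graphs therefore yields a smaller output, and adding it either fails with a shape mismatch or (when that output has a single row) broadcasts silently. In PyG, passing `size=num_graphs` to `global_mean_pool` avoids this. The sketch below assumes `num_graphs` is available (e.g. something like `data.num_graphs` on the batched object). It uses hypothetical helpers `mean_pool` and `pool_present_types`, written in plain PyTorch so the toy example runs without PyG. They mimic `global_mean_pool(x, batch, size=num_graphs)`.

```python
import torch


def mean_pool(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Hypothetical stand-in for global_mean_pool(x, batch, size=num_graphs).

    Returns a [num_graphs, F] tensor; graphs with no nodes of this type get zeros.
    """
    out = torch.zeros(num_graphs, x.size(-1), dtype=x.dtype, device=x.device)
    out.index_add_(0, batch, x)  # sum node features per graph
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return out / counts.unsqueeze(-1).to(x.dtype)  # sum -> mean


def pool_present_types(x_dict, batch_dict, num_graphs: int):
    """Sum per-graph means over node types that have at least one node."""
    out = 0
    for node_type, x in x_dict.items():
        if x.numel() == 0:  # node type absent from the whole batch
            continue
        out = out + mean_pool(x, batch_dict[node_type], num_graphs)
    return out


# Toy batch of 2 graphs: 'author' only appears in graph 0, 'venue' is empty.
x_dict = {
    "paper": torch.tensor([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]]),
    "author": torch.tensor([[2.0, 0.0]]),
    "venue": torch.empty(0, 2),
}
batch_dict = {
    "paper": torch.tensor([0, 0, 1]),
    "author": torch.tensor([0]),
    "venue": torch.empty(0, dtype=torch.long),
}

# Graph 0: paper mean [2, 2] + author mean [2, 0]; graph 1: paper mean [5, 5].
pooled = pool_present_types(x_dict, batch_dict, num_graphs=2)
print(pooled)
```

Fixing the output size means every node type contributes a tensor of the same shape, so the sum stays aligned graph by graph even when a type is missing from some graphs.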

Let me know if this works for you!

Answer selected by DogitoErgoSum
Category: Q&A