I suggest using a dataset transform for that, which lets you add the global node feature to each graph separately:

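import torch

def my_transform(data):
    # Mean over all nodes of the graph, shape [1, num_features]
    mean = data.x.mean(dim=0, keepdim=True)
    # Append the mean as one extra (global) node row
    data.x = torch.cat([data.x, mean], dim=0)
    return data

Dataset(..., transform=my_transform)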
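
To check what such a transform does to a single graph, here is a minimal sketch. It assumes only that the graph object exposes a node feature matrix `x` of shape `[num_nodes, num_features]`. The `SimpleNamespace` stand-in and the name `append_mean_node` are hypothetical and used purely for illustration; in practice the transform would receive the dataset's own data objects. Note that the appended row is an isolated node unless the graph's edges are extended as well.

```python
from types import SimpleNamespace

import torch


def append_mean_node(data):
    # Mean over the node dimension; keepdim=True keeps shape [1, num_features].
    mean = data.x.mean(dim=0, keepdim=True)
    # Append the mean as one extra row, so x goes from [N, F] to [N + 1, F].
    data.x = torch.cat([data.x, mean], dim=0)
    return data


# Hypothetical stand-in for one graph: 3 nodes with 2 features each.
graph = SimpleNamespace(x=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
graph = append_mean_node(graph)

print(graph.x.shape)  # torch.Size([4, 2])
print(graph.x[-1])    # tensor([3., 4.])
```

Because the transform runs on each graph individually, the mean is computed per graph and never mixes features across graphs in a batch.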

Answer selected by tferber