This should be straightforward to add: instead of applying a global pooling operator, you simply read out the embedding of this single node. Gradients will still flow through the other nodes due to their connectivity with the central node.

import torch
from torch_geometric.loader import DataLoader

# Mark the central node of each graph before batching:
data.node_index = torch.tensor([center_node], dtype=torch.long)
loader = DataLoader(...)

for data in loader:
    x = conv(data.x, data.edge_index).relu()
    x = conv(x, data.edge_index)
    x_node = x[data.node_index]  # read out the central node only
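
To make the gradient-flow point concrete, here is a minimal end-to-end sketch. The GCNConv backbone, the feature sizes, the CenterNodeGNN name, and the per-graph scalar target data.y are all assumptions for illustration, not part of the original answer:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class CenterNodeGNN(torch.nn.Module):  # hypothetical model name
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, node_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x[node_index]  # per-graph prediction from the central node only

model = CenterNodeGNN(in_channels=16, hidden_channels=32, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for data in loader:
    optimizer.zero_grad()
    out = model(data.x, data.edge_index, data.node_index)
    loss = F.mse_loss(out.squeeze(-1), data.y)  # assumes one scalar target per graph
    loss.backward()  # gradients reach all nodes through message passing
    optimizer.step()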

Note that data.node_index will be correctly incremented in a mini-batch scenario: as long as the attribute is named something like *_index, PyG offsets it by the cumulative number of nodes when collating graphs into a batch.
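
For example, here is a small sketch of that offset behavior via Batch.from_data_list (which DataLoader uses under the hood); the graph sizes and indices are made up for illustration:

import torch
from torch_geometric.data import Batch, Data

g1 = Data(x=torch.randn(5, 16),
          edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
          node_index=torch.tensor([2]))
g2 = Data(x=torch.randn(4, 16),
          edge_index=torch.tensor([[0, 1], [1, 2]]),
          node_index=torch.tensor([0]))

batch = Batch.from_data_list([g1, g2])
print(batch.node_index)  # tensor([2, 5]): g2's index 0 is shifted by g1's 5 nodes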
