I am building a classifier that predicts the interaction (yes / no) between two proteins, each represented as a graph of amino acids (each node has some feature vector h). Each graph passes through its own set of GCN/GAT layers, and then the output embedding vectors from each are concatenated. This concatenated vector is then the input to the 'classifier' part of the network, which is just 2 fully connected layers and 1 output layer.

The thing is, for this "graph classification" task, I have a specific node of interest whose embedding I want to use as the whole graph's output. You can think of this as each graph having a specific node at the "centre", and this centre node is what I want the classifier to use (along with the other graph's embedding, concatenated). This is not simply an additional "special" node artificially added like I've seen in some graph classification tasks -- the 'centre node' is actually one of the already existing nodes in the graph structure. I have the ID of this node for every training / testing example, and when the model is used on unseen data, the ID of the node at the "centre" of the input graph will also be known. So it will never be ambiguous which node should be designated the "centre node", for any use case my model will be used for.

I guess another way of describing my aim is as a node classification task for n graphs, where each graph has only 1 labelled node; the rest of the nodes are not considered / labelled. However, I want to concatenate this node's embedding (after the GCN/GAT layers) with another graph's output embedding before feeding it through a classifier and producing the desired output label. Alternatively, you could consider what I want to do as a graph classification problem, but one where the graph is classified from the "point of view" of a specific node.

I am not sure if this structure would even be differentiable for the loss function. Any ideas on how to set this up would be greatly appreciated! Thanks so much in advance everyone
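For concreteness, here is a rough sketch of the architecture I have in mind (module names, hidden sizes, and the `centre_a` / `centre_b` index arguments are placeholders of mine, not existing code; it assumes PyTorch Geometric's `GCNConv`, but `GATConv` would slot in the same way):

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import GCNConv


class ProteinTower(nn.Module):
    """GCN layers for one protein graph; readout = the centre node's embedding."""
    def __init__(self, in_dim, hid_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, hid_dim)

    def forward(self, x, edge_index, centre_idx):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x[centre_idx]                      # keep only the centre node's embedding


class InteractionClassifier(nn.Module):
    """Concatenates the two centre-node embeddings and predicts interaction yes/no."""
    def __init__(self, in_dim, hid_dim):
        super().__init__()
        self.tower_a = ProteinTower(in_dim, hid_dim)
        self.tower_b = ProteinTower(in_dim, hid_dim)
        self.fc1 = nn.Linear(2 * hid_dim, hid_dim)
        self.fc2 = nn.Linear(hid_dim, hid_dim)
        self.out = nn.Linear(hid_dim, 1)

    def forward(self, x_a, edge_index_a, centre_a, x_b, edge_index_b, centre_b):
        h_a = self.tower_a(x_a, edge_index_a, centre_a)
        h_b = self.tower_b(x_b, edge_index_b, centre_b)
        h = torch.cat([h_a, h_b], dim=-1)         # concatenated per-graph embeddings
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        return self.out(h)                        # logit, e.g. for BCEWithLogitsLoss
```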
This should be straightforward to add. Instead of doing a global pooling, you simply do readout of this single node. Gradients will still flow through other nodes due to their connectivity with this central node.

```python
# store the centre-node ID on each Data object
data.node_index = torch.tensor([center_node], dtype=torch.long)

loader = DataLoader(...)
for data in loader:
    x = conv(data.x, data.edge_index).relu()
    x = conv(x, data.edge_index)
    # readout: keep only the centre-node embedding(s)
    x_node = x[data.node_index]
```
Note that `data.node_index` will get correctly incremented in a mini-batch scenario whenever it is called something like `*_index`.
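A minimal sketch of that batching behaviour, in case it helps (the graph sizes, feature dimensions, and centre-node IDs below are made up; it assumes a recent PyTorch Geometric where `DataLoader` lives in `torch_geometric.loader`):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# two toy graphs with 4 and 3 nodes
g1 = Data(x=torch.randn(4, 16), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))
g1.node_index = torch.tensor([2], dtype=torch.long)   # centre node of graph 1

g2 = Data(x=torch.randn(3, 16), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2.node_index = torch.tensor([0], dtype=torch.long)   # centre node of graph 2

loader = DataLoader([g1, g2], batch_size=2)
batch = next(iter(loader))
print(batch.node_index)   # tensor([2, 4]): graph 2's index is offset by graph 1's 4 nodes
```

Because the attribute name contains "index", batching offsets each graph's `node_index` by the number of nodes stacked before it, so `x[batch.node_index]` picks out exactly one centre-node embedding per graph.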