Node classification with additional "None" class and PyG batching #9095
Replies: 2 comments 1 reply
-
I might be missing something, but why can't you simply add an additional target to train against?
-
Sorry for the late response. My proposed solution (in the second code snippet above) assumes that I do have an additional target for none. However, since I have the concatenated probs as output, I also stack both targets y_none and y_node together and calculate a single loss (NLLLoss applied to log-probs), which should be equivalent to using CrossEntropy in standard classification. Was your suggestion going in this direction, or would you just calculate two separate losses? The advantage of having one loss is that it allows taking class imbalance into account when none is the prevalent class, e.g. by calculating FocalLoss. I'm just wondering if there's a simpler way to handle this than what I currently do.
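To illustrate the class-imbalance point, a minimal sketch of a focal-style weighting applied to such log-probs (names and gamma are illustrative, not from my actual code):

```python
import torch

def focal_nll(log_probs, target, gamma=2.0):
    """Focal-style NLL: down-weights easy examples via (1 - p_t) ** gamma."""
    # log_probs: (B, C, T) log-probabilities; target: (B, T) class indices
    logp_t = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)  # (B, T)
    p_t = logp_t.exp()
    return (-(1 - p_t) ** gamma * logp_t).mean()

# Dummy usage: 2 graphs, 6 classes ("none" + 5 nodes), 4 timesteps
log_probs = torch.log_softmax(torch.randn(2, 6, 4), dim=1)
loss = focal_nll(log_probs, torch.randint(0, 6, (2, 4)))
```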
-
I'm working on a problem that can be framed as node classification, where I need to account for the option that none of the nodes is "selected". Additionally, the number of nodes may vary between training examples (graphs), and there is a temporal dimension to consider. However, the graph structure remains constant across time within a single training example, with all nodes connected (though edge attributes differ).
Each training example's node representation has shape (N, C, T), where N is the number of nodes in the graph, C is the number of channels, and T is the number of timesteps. Targets have shape (N, T). Using the torch_geometric DataLoader, graphs with varying node counts can be stacked together, resulting in shape (BN, C, T), where BN is the total number of nodes across the batched graphs (BN != BS*N).
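A minimal sketch of this batching, with illustrative shapes:

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

C, T = 8, 4
graphs = []
for n in (3, 5):  # two graphs with different node counts
    # Fully connected graph (one edge direction shown for brevity)
    edge_index = torch.combinations(torch.arange(n), r=2).t()
    graphs.append(Data(x=torch.randn(n, C, T), edge_index=edge_index,
                       y=torch.zeros(n, T)))

batch = next(iter(DataLoader(graphs, batch_size=2)))
print(batch.x.shape)  # torch.Size([8, 8, 4]) -> (BN, C, T) with BN = 3 + 5
print(batch.batch)    # tensor([0, 0, 0, 1, 1, 1, 1, 1]) maps nodes to graphs
```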
In the classification head, I use a Linear layer (BN, T, C) -> (BN, T) to map the C node embeddings to one logit per node and timestep. As I understand it, to classify the nodes that are present, I could then do something like the following (a sketch with illustrative names, assuming the target is the index of the selected node per timestep):
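```python
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_dense_batch

# Illustrative inputs: two graphs with 3 and 5 nodes, T = 4 timesteps
num_nodes = torch.tensor([3, 5])
batch = torch.repeat_interleave(torch.arange(2), num_nodes)  # (BN,)
logits = torch.randn(int(num_nodes.sum()), 4)                # (BN, T)

# Pad absent nodes with -inf so softmax assigns them probability 0
dense_logits, mask = to_dense_batch(logits, batch,
                                    fill_value=float('-inf'))  # (B, N_max, T)

y = torch.tensor([[0, 2, 1, 0], [4, 0, 3, 1]])  # (B, T): selected node index
loss = F.cross_entropy(dense_logits, y)         # softmax over the node dim
```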
In this approach, if a graph in the batch contains fewer nodes than the maximum node count, those padded nodes effectively receive a probability of 0 after applying softmax within the `CrossEntropy` loss calculation. However, this method doesn't address the possibility of a "None" classification. One idea I've considered is introducing a fake node with an all-1 input and connecting it to every other node. This additional node would effectively learn a "global" embedding, indicating whether none of the real nodes should be classified.
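A sketch of how such a fake node could be added (an illustrative helper, not code I currently use; edge_attr for the fake edges would need the same treatment):

```python
import torch
from torch_geometric.data import Data

def add_fake_node(data: Data) -> Data:
    """Append one all-1 node and connect it to every real node (both directions)."""
    n = data.num_nodes
    fake_x = torch.ones(1, *data.x.shape[1:])  # all-1 input, same (C, T) dims
    src = torch.full((n,), n)                  # the fake node gets index n
    dst = torch.arange(n)
    extra = torch.cat([torch.stack([src, dst]),
                       torch.stack([dst, src])], dim=1)
    data.x = torch.cat([data.x, fake_x], dim=0)
    data.edge_index = torch.cat([data.edge_index, extra], dim=1)
    return data
```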
The problem with this approach is that my model has multiple heads, and most of them do some form of regression on real nodes/edges. The additional fake node makes this quite cumbersome, as we have to drop it and its fake edges every time we want to make predictions on real nodes/edges.
As a potentially simpler alternative, I could employ two loss functions: one for predicting a binary target (call it y_none; it is 1 whenever the none class should be predicted and 0 otherwise) and one for classifying the existing nodes. For predicting y_none, I could use a sigmoid preceded by a global pooling operation that aggregates the embeddings of all nodes and maps the pooled embedding to a single output with a linear layer; for the nodes, a softmax classification. The node classification loss would then be added to the binary classification loss only when the ground-truth target indicates one of the nodes. I could even stack all the probabilities together, like below, and use NLLLoss on the all_prob.
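A sketch of that stacking (illustrative names; index 0 stands for "none" and i + 1 for node i of each graph):

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch

B, C, T = 2, 8, 4                        # graphs, channels, timesteps
num_nodes = torch.tensor([3, 5])
batch = torch.repeat_interleave(torch.arange(B), num_nodes)   # (BN,)
node_emb = torch.randn(int(num_nodes.sum()), C, T)            # (BN, C, T)

node_head = torch.nn.Linear(C, 1)        # one logit per node and timestep
none_head = torch.nn.Linear(C, 1)        # "none" logit from pooled embedding

# Node logits: (BN, C, T) -> (BN, T, C) -> (BN, T)
logits = node_head(node_emb.permute(0, 2, 1)).squeeze(-1)

# Dense batching; -inf padding gives absent nodes probability 0 after softmax
dense_logits, mask = to_dense_batch(logits, batch, fill_value=float('-inf'))
node_prob = dense_logits.softmax(dim=1)                       # (B, N_max, T)

# p_none: sigmoid over a linear map of the globally pooled node embeddings
pooled = global_mean_pool(node_emb, batch)                    # (B, C, T)
p_none = torch.sigmoid(none_head(pooled.permute(0, 2, 1))).squeeze(-1)  # (B, T)

# Stack: slot 0 = "none"; node probabilities are rescaled by 1 - p_none
all_prob = torch.cat([p_none.unsqueeze(1),
                      node_prob * (1 - p_none).unsqueeze(1)], dim=1)

y_all = torch.randint(0, 4, (B, T))      # (B, T): 0 = none, i + 1 = node i
loss = F.nll_loss(torch.log(all_prob.clamp_min(1e-12)), y_all)
```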
However, this is numerically less stable than the usual CrossEntropy (not only because of using softmax + NLLLoss, but also because of multiplying probabilities together). What's more, we need to adjust `mask`, `batch`, etc. by adding elements. It still feels like a neater solution to this problem might exist. Is there a more straightforward solution that I'm overlooking?
Side question: Do I understand correctly that none of the available graph layers work with global_attr, and only some allow exchanging information between nodes and edges via edge_attr?