Scatter version of Cross-Entropy Loss #9687
leonardcaquot94 started this conversation in Ideas
Context
I’m working on a node selection task where I compute a logit for each node in a batch and train these logits with a cross-entropy loss against selection probabilities. Since one node is selected per graph and the number of nodes varies between graphs, the cross-entropy must be computed per graph, using the batch tensor to group nodes.
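To make the per-graph grouping concrete, here is a minimal loop-based reference in PyTorch. It assumes logits is a 1-D tensor with one score per node, target holds the selection probabilities (summing to 1 within each graph), and batch is the usual PyG vector mapping each node to its graph index; the helper name per_graph_cross_entropy_loop is hypothetical. Each graph is treated as a single classification sample whose classes are its nodes (probability targets in F.cross_entropy need PyTorch 1.10 or newer).

```python
import torch
import torch.nn.functional as F


def per_graph_cross_entropy_loop(logits: torch.Tensor, target: torch.Tensor,
                                 batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: one cross-entropy value per graph via a Python loop."""
    num_graphs = int(batch.max()) + 1
    losses = []
    for g in range(num_graphs):
        mask = batch == g
        # Treat graph g as one sample (batch size 1) whose "classes" are its nodes.
        # A float target with the same shape as the input is read as class probabilities.
        losses.append(F.cross_entropy(logits[mask].unsqueeze(0),
                                      target[mask].unsqueeze(0)))
    return torch.stack(losses)


# Two graphs: graph 0 has 3 nodes, graph 1 has 2 nodes.
logits = torch.tensor([2.0, 0.5, -1.0, 0.0, 1.0])
target = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0])  # one selected node per graph
batch = torch.tensor([0, 0, 0, 1, 1])
print(per_graph_cross_entropy_loop(logits, target, batch))  # one loss per graph, shape [2]
```

The loop is easy to check but issues one small cross_entropy call per graph, which is the overhead a scatter-based version would avoid.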
I propose adding a scatter-based cross-entropy function, similar to scatter_softmax in torch_geometric.utils, to compute one loss per graph.
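A minimal vectorized sketch of such a function, using only core PyTorch ops instead of a loop. The name scatter_cross_entropy and its signature are hypothetical, not an existing torch_geometric API. It assumes logits is a 1-D tensor of node scores, target holds per-node selection probabilities that sum to 1 within each graph, batch maps nodes to graphs, and PyTorch 1.12 or newer (for Tensor.scatter_reduce). It computes a scatter log-softmax through a per-graph log-sum-exp, then sums -p_i log q_i over the nodes of each graph.

```python
from typing import Optional

import torch
from torch import Tensor


def scatter_cross_entropy(logits: Tensor, target: Tensor, batch: Tensor,
                          num_graphs: Optional[int] = None) -> Tensor:
    """Hypothetical sketch: cross-entropy computed separately for each graph.

    logits: [num_nodes] one score per node.
    target: [num_nodes] selection probability per node (sums to 1 per graph).
    batch:  [num_nodes] graph index of each node.
    Returns a tensor of shape [num_graphs].
    """
    if num_graphs is None:
        num_graphs = int(batch.max()) + 1
    zeros = logits.new_zeros(num_graphs)

    # Per-graph maximum, used only for numerical stability. The log-softmax
    # does not depend on this shift, so it is computed on detached logits.
    graph_max = logits.new_full((num_graphs,), float("-inf")).scatter_reduce(
        0, batch, logits.detach(), reduce="amax", include_self=True)
    shifted = logits - graph_max[batch]

    # Per-graph log-sum-exp turns this into a scatter version of log_softmax.
    sum_exp = zeros.index_add(0, batch, shifted.exp())
    log_probs = shifted - sum_exp.log()[batch]

    # Sum the node-level terms -p_i * log q_i within each graph.
    return zeros.index_add(0, batch, -target * log_probs)
```

Calling scatter_cross_entropy(logits, target, batch).mean() would give a scalar for backpropagation. If torch_scatter is installed, its scatter_log_softmax could replace the manual log-sum-exp steps.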
Implementation
Reduction
The reduce parameter mirrors the behavior of the reduction parameter in torch.nn.functional.cross_entropy, but it is used to compute one loss per graph. An additional reduction can be applied outside this function to obtain a single scalar value if needed.
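One possible reading of the two reduction levels, sketched under assumptions: node_loss holds node-level terms such as -p_i log q_i, the hypothetical helper reduce_per_graph combines them within each graph by "sum" (the usual cross-entropy) or "mean" (averaged over the graph's nodes), and the final scalar comes from a separate reduction over graphs.

```python
import torch


def reduce_per_graph(node_loss: torch.Tensor, batch: torch.Tensor,
                     num_graphs: int, reduce: str = "sum") -> torch.Tensor:
    """Hypothetical helper: combine node-level loss terms into one value per graph."""
    out = node_loss.new_zeros(num_graphs).index_add(0, batch, node_loss)
    if reduce == "mean":
        # Divide by each graph's node count (at least 1, to avoid 0/0 for empty graphs).
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
        out = out / counts.to(node_loss.dtype)
    elif reduce != "sum":
        raise ValueError(f"unsupported reduce: {reduce!r}")
    return out


node_loss = torch.tensor([0.5, 0.0, 0.0, 0.0, 0.25])
batch = torch.tensor([0, 0, 0, 1, 1])
per_graph = reduce_per_graph(node_loss, batch, num_graphs=2, reduce="mean")
scalar = per_graph.mean()  # outer reduction applied outside the per-graph function
print(per_graph, scalar)
```

Keeping the outer reduction separate lets the caller choose, for example, a weighted mean over graphs without changing the per-graph function.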