
I'm not sure this is the "best" solution, but here is one approach.
Use the Gumbel-max trick (the idea behind the Gumbel reparameterization trick):

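  • transform the node probabilities into log probabilities (log p)
  • sample a vector of the same length from the standard Gumbel distribution (v)
  • add them (log p + v)
  • use a scatter with reduce='max' over the batch vector to find, for each graph, the node with the largest score; you need the argmax, not just the max value (torch_scatter.scatter_max returns both): argmax(log p + v)
  • these argmax indexes are the sampled nodes, one per graph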
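A minimal pure-Python sketch of the steps above, assuming `probs` holds one probability per node and `batch` is a node-to-graph assignment vector like the one a PyG `Batch` carries. The function name `gumbel_max_per_graph` is hypothetical. In PyTorch the per-graph loop would be replaced by a scatter such as `torch_scatter.scatter_max`, which returns both the per-graph maxima and their argmax indexes.

```python
import math
import random


def gumbel_max_per_graph(probs, batch, rng=random):
    """Draw one node per graph using the Gumbel-max trick (hypothetical helper).

    probs[i] is the (possibly unnormalized) probability of node i and
    batch[i] is the id of the graph node i belongs to (assumed layout,
    mirroring a PyG `batch` vector).
    Returns a dict mapping graph id -> index of the sampled node.
    Graphs whose nodes all have zero probability are left out.
    """
    best_score = {}  # graph id -> largest log p + v seen so far
    best_node = {}   # graph id -> node index achieving that score
    for node, (p, graph) in enumerate(zip(probs, batch)):
        if p <= 0.0:
            continue  # log p = -inf, so this node can never be the argmax
        # Draw U strictly inside (0, 1) so both logarithms are finite.
        u = rng.random()
        while u == 0.0:
            u = rng.random()
        v = -math.log(-math.log(u))  # standard Gumbel(0, 1) noise
        score = math.log(p) + v      # log p + v
        # Per-graph running max that also remembers the argmax,
        # i.e. what a scatter with reduce='max' plus argmax computes.
        if graph not in best_score or score > best_score[graph]:
            best_score[graph] = score
            best_node[graph] = node
    return best_node


# Usage: two graphs packed into one batch.
probs = [0.1, 0.6, 0.3, 0.5, 0.5]
batch = [0, 0, 0, 1, 1]
print(gumbel_max_per_graph(probs, batch, random.Random(0)))
```

Because the argmax does not change when the same constant is added to every score in a graph, the probabilities do not need to be normalized per graph; unnormalized weights or logits work too.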

I think it would be a good idea to add something like this to the Batch class.

Answer selected by moe-assal
Category: Q&A