
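Sorry, it's not scatter but gather. Run `scatter_max` on the absolute values to get the `argmax` of each group, then use that `argmax` with `torch.gather` to pick out the original (signed) entries: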

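```python
import torch
from torch_scatter import scatter_max

x = torch.tensor([[-3, 2], [2, -3], [-2, 1], [1, -2]], dtype=torch.float)
index = torch.tensor([0, 0, 1, 1])  # rows 0-1 form group 0, rows 2-3 form group 1

# Per group and per column, find the row holding the entry with the largest magnitude.
out, argmax = scatter_max(x.abs(), index, dim=0)
print(argmax)  # tensor([[0, 1], [2, 3]])

# Gather the original signed values at those rows: out[i][j] = x[argmax[i][j]][j].
out = torch.gather(x, 0, argmax)
print(out)  # tensor([[-3., -3.], [-2., -2.]])
```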
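If you would rather not depend on `torch_scatter`, here is a minimal sketch that uses only core PyTorch, assuming PyTorch 1.12 or later (where `Tensor.scatter_reduce` supports `"amax"` and `"amin"`). Note that it swaps in a different technique rather than recovering an `argmax`: the largest-magnitude entry of a group is always either the group's maximum or its minimum, so it computes both and keeps whichever has the larger absolute value. `signed_absmax` is a hypothetical helper name, not part of any library.

```python
import torch


def signed_absmax(x: torch.Tensor, index: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Hypothetical helper: per group (along dim 0) and per column, return the
    entry with the largest absolute value, keeping its original sign."""
    # scatter_reduce needs an index with the same shape as src,
    # so broadcast the per-row group index across all columns.
    idx = index.unsqueeze(-1).expand_as(x)
    base = x.new_zeros(num_groups, x.size(1))
    # include_self=False: the zeros in `base` do not take part in the reduction
    # (groups that receive no rows simply stay 0).
    gmax = base.scatter_reduce(0, idx, x, reduce="amax", include_self=False)
    gmin = base.scatter_reduce(0, idx, x, reduce="amin", include_self=False)
    # The largest-magnitude entry is either the max or the min of the group.
    # On an exact tie (e.g. 3 vs -3) this keeps the positive value.
    return torch.where(gmax.abs() >= gmin.abs(), gmax, gmin)


x = torch.tensor([[-3, 2], [2, -3], [-2, 1], [1, -2]], dtype=torch.float)
index = torch.tensor([0, 0, 1, 1])
print(signed_absmax(x, index, num_groups=2))  # tensor([[-3., -3.], [-2., -2.]])
```

On this input it gives the same result as the `scatter_max` + `gather` version. The two can differ only on exact magnitude ties within a group (such as 3 and -3), where this sketch always prefers the positive value while the `argmax` route returns whichever row `scatter_max` reports.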

Replies: 1 comment, 6 replies
Participants (2): @rusty1s, @baon6052
Answer selected by baon6052
Category: Q&A
Labels: None yet