
Yes. If you want to use `segment_csr`, you will need to sort the matrices by their batch assignment and compute a new `ptr` tensor from the sorted `batch` vector:

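```python
import torch
# `batch` must be sorted; ptr gets num_segments + 1 entries, the last one equal to batch.numel()
arange = torch.arange(int(batch.max()) + 2, device=batch.device)
ptr = torch.bucketize(arange, batch)
```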

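A minimal end-to-end sketch of the steps above, assuming a hypothetical unsorted input `x` with one row per entry and a `batch` vector assigning rows to batch elements. The helper names `batch_to_ptr` and `segment_sum_csr` are illustrative, not part of any library. So that the sketch depends only on PyTorch, the CSR sum is computed with prefix sums as a stand-in for `torch_scatter.segment_csr(..., reduce="sum")`, and it is checked against an unsorted `index_add_` reference.

```python
import torch


def batch_to_ptr(batch: torch.Tensor) -> torch.Tensor:
    """CSR pointer for a *sorted* batch vector (hypothetical helper).

    ptr[i] is the first position of batch value i and ptr[-1] == batch.numel(),
    so ptr has num_segments + 1 entries; empty segments yield repeated values.
    """
    arange = torch.arange(int(batch.max()) + 2, device=batch.device)
    return torch.bucketize(arange, batch)


def segment_sum_csr(src: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
    """Sum the rows of `src` in each range [ptr[i], ptr[i + 1]) via prefix sums.

    Plain-PyTorch stand-in for a CSR segment sum (hypothetical helper).
    """
    zero = src.new_zeros((1,) + tuple(src.shape[1:]))
    csum = torch.cat([zero, src.cumsum(dim=0)], dim=0)
    return csum[ptr[1:]] - csum[ptr[:-1]]


# Hypothetical data: 6 rows with 2 features, assigned to 3 batch elements.
x = torch.arange(12, dtype=torch.float).view(6, 2)
batch = torch.tensor([2, 0, 1, 0, 2, 1])

# Sort so that the rows of each batch element form one contiguous block.
batch_sorted, perm = batch.sort()
x_sorted = x[perm]

ptr = batch_to_ptr(batch_sorted)
out = segment_sum_csr(x_sorted, ptr)

# Reference: scatter-style sum on the original, unsorted rows.
ref = torch.zeros(3, 2).index_add_(0, batch, x)

print(ptr)                       # tensor([0, 2, 4, 6])
print(torch.allclose(out, ref))  # True
```

With `torch_scatter` installed, `segment_csr(x_sorted, ptr, reduce="sum")` should produce the same `out`. Note that `arange` runs up to `batch.max() + 1`, so the final pointer closes the last segment.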
Replies: 1 comment, 6 replies (from @rusty1s and @dhorka)
Answer selected by dhorka
Category
Q&A
2 participants