
Hi, thanks for the question. The fundamental problem here is that `scan` has some unavoidable per-iteration overhead on GPU (see e.g. #16106 (comment) for a general discussion). I'm not sure what to recommend beyond "don't use `scan` for this kind of operation on GPU".

I suspect the super-linear scaling with N comes from the static-shape requirement: within each of the N iterations you end up operating on intermediate arrays of size O(N), so the total work grows as O(N²).
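To illustrate the static-shape effect, here is a minimal sketch. The operation (a running mean of `x[:i+1]`) is a hypothetical stand-in, since the original code isn't shown here. Inside `lax.scan` the step index is traced, so `x[:i+1]` can't be sliced (its shape would depend on `i`). Instead, each step has to mask the full length-N array, which gives O(N) work per step and O(N²) in total. When the computation can be expressed without a loop, a vectorized version (here, `jnp.cumsum`) avoids both the quadratic work and the per-iteration `scan` overhead.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def prefix_means_scan(x):
    """Running mean of x[:i+1] for every i, written with lax.scan (O(N^2) work)."""
    n = x.shape[0]
    idx = jnp.arange(n)

    def step(carry, i):
        # i is a traced value, so x[:i+1] would have a dynamic shape.
        # Use a full-length mask instead: every step touches all N elements.
        mask = idx <= i
        total = jnp.sum(jnp.where(mask, x, 0.0))
        return carry, total / (i + 1)

    _, means = lax.scan(step, None, idx)
    return means


@jax.jit
def prefix_means_vectorized(x):
    """Same result with one cumulative sum (O(N) work, no sequential loop)."""
    n = x.shape[0]
    return jnp.cumsum(x) / jnp.arange(1, n + 1)


x = jnp.arange(1.0, 6.0)
print(prefix_means_scan(x))
print(prefix_means_vectorized(x))
```

Both functions compute the same values. The difference is that the `scan` version does N sequential steps, each on an array of size N.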

Perhaps this is a case where Pallas could be useful? (ping @sharadmv, who might have ideas there)

Replies: 2 comments

Answer selected by coreyjadams
Category
Q&A
Labels
None yet
2 participants