I went down the rabbit hole of trying to implement argmin / findmin using AK primitives without scalar indexing.
Essentially, the solution would be this:
https://github.com/JuliaGPU/GPUArrays.jl/blob/master/src/host/indexing.jl#L229
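That approach requires the multi-iterator form mapreduce(f, op, itr...) with f = tuple, so that each element gets paired with its index as a tuple (idx, val) before the reduction picks the minimum.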
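
To make the idea concrete, here is a minimal CPU-only sketch using plain Base `mapreduce`, not AK. The names `min_pair` and `findmin_mapreduce` are made up for illustration and are not part of AK or GPUArrays. The tie-breaking rule (lowest index wins) is my assumption, chosen to mimic `Base.findmin`. A GPU reduction would likely also need a neutral `init` element, which is left out here.

```julia
# Hypothetical helper: reduction operator over (idx, val) tuples.
# Keeps the pair with the smaller value; on equal values, keeps the smaller index.
function min_pair(a, b)
    ia, va = a
    ib, vb = b
    if isless(vb, va) || (isequal(va, vb) && ib < ia)
        return b
    end
    return a
end

# Hypothetical helper: findmin as one mapreduce over (idx, val) tuples.
# Assumes a non-empty array; indices are the linear indices from eachindex.
function findmin_mapreduce(x::AbstractArray)
    isempty(x) && throw(ArgumentError("collection must be non-empty"))
    # `tuple` pairs each index with its value; `min_pair` reduces the pairs.
    idx, val = mapreduce(tuple, min_pair, eachindex(x), x)
    return (val, idx)  # same (value, index) order as Base.findmin
end

# argmin is just the index part of the result.
argmin_mapreduce(x) = findmin_mapreduce(x)[2]
```

Since no element is ever read by a standalone `x[i]`, the whole thing is a single reduction, which is the property needed to avoid scalar indexing. Because `min_pair` compares with `isless`, NaN handling differs from `Base.findmin`, so treat this as a sketch for NaN-free data.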