Skip to content

Conversation

@CarloLucibello
Copy link
Member

@CarloLucibello CarloLucibello commented Sep 30, 2021

Initial steps to fix #352

TODO

  • give more thought to the interface (which outputs do we expect?)
  • add rrule or write it in an AD friendly way
  • gpu support

end


function topk(x::AbstractArray{T,N}, k; rev=false, dims=nothing) where {T,N}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems a little odd that "top k" gives the smallest values:

julia> r = [3,11,5,13,7];

julia> topk(r, 2)
([3, 5], CartesianIndex{1}[CartesianIndex(1,), CartesianIndex(3,)])

julia> sort(r)
5-element Vector{Int64}:
  3
  5
  7
 11
 13

I understand it's following what sort does, but perhaps it needs a better name, or a different default? Is the typical use to select the largest elements?

Copy link
Member

@ToucheSir ToucheSir Oct 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PyTorch has a largest param to control this: https://pytorch.org/docs/stable/generated/torch.topk.html#torch.topk. As you noted, the default is to return the largest k values. I think passing rev=!rev to partialsortperm would do the trick.

Copy link
Member

@mcabbott mcabbott Oct 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks. So they default to "largest=True" i.e. "rev=false" here.

They allow "sorted=False", but with partialsortperm, they will be sorted by default. I presume there are no downsides to getting a sorted result, it's just that they (perhaps) do something cheaper when this isn't required?

One nice thing about following sort is that you can pass all of its keywords along, by=..., lt=..., in case people want these... and you can just point at Base's documentation. It wouldn't be crazy to do that and just note that rev=true is the default here.

@mcabbott
Copy link
Member

mcabbott commented Oct 10, 2021

I think the obvious way to make this AD-able would be just rely on the gradient for getindex, which will store the indices for the backwards pass. Maybe it's as simple as this:

topk(x::AbstractArray, k::Integer; kw...) = x[topkind(x, k; kw...)]

topkind(x::AbstractVector, k::Integer; dims::Integer=1, rev=true, kw...) = (@assert dims==1; partialsortperm(x, 1:k; rev=rev, kw...))
topkind(x::AbstractArray, k::Integer; dims::Integer=1, rev=true, kw...) = mapslices(y -> topkind(y, k; rev=rev, kw...), x; dims=dims)
@nograd topkperm

Unlike the PR this doesn't return the permutation, but do you need it for anything else? It also defaults to first dimension, and won't accept dims=:.

But to make this GPU-friendly... There should be no aliasing issues with the gradient. There is partialsort!(::CuVector, ...) but no sortperm and no mapslices. There is a sort!(::CuMatrix; dims). Both of these call the same quicksort! so perhaps something can be built on that.

@ToucheSir
Copy link
Member

For GPU compat, TF's XLA implementation should be using all out of place ops: https://github.com/tensorflow/tensorflow/blob/8d72537c6abf5a44103b57b9c2e22c14f5f49698/tensorflow/core/tpu/kernels/topk_ops.cc. Of course we can't rely on any of the optimizations in XLA, so a more ideal implementation would probably look the PyTorch one here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/TensorTopK.cu.

@mcabbott
Copy link
Member

What I hoped might work, re-using what CUDA.jl already has, doesn't seem to -- can this easily be fixed?

julia> sort(tuple.(cu(rand(10)), 1:10), by=first)
ERROR: InvalidIRError: compiling kernel qsort_kernel(CuDeviceVector{Tuple{Float32, Int64}, 1}, Int64, Int64, Bool, Val{true}, Int64, Nothing, typeof(isless), typeof(first), Val{1}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to zero)
Stacktrace:
 [1] bitonic_median
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:217
 [2] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:405
 [3] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:378
Reason: unsupported dynamic function invocation (call to zero)
Stacktrace:
 [1] bubble_sort
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:274
 [2] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:401
 [3] qsort_kernel
   @ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:378
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{typeof(CUDA.Quicksort.qsort_kernel), Tuple{CuDeviceVector{Tuple{Float32, Int64}, 1}, Int64, Int64, Bool, Val{true}, Int64, Nothing, typeof(isless), typeof(first), Val{1}}}}, args::LLVM.Module)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/fG3xK/src/validation.jl:111
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/fG3xK/src/driver.jl:319 [inlined]

@mcabbott
Copy link
Member

mcabbott commented Oct 10, 2021

OK, this seems to work:

function topkperm(x::CuArray, k::Integer; dims::Integer=1, rev=true, lt=isless, by=identity)
    tups = tuple.(x, reshape(axes(x,dims), fill(1, dims-1)..., :))
    CUDA.quicksort!(tups; lt=(rev ? !lt : lt), by=byfirst, dims=dims, partial_k=1:k)
    tv = view(tups, ntuple(d -> d==dims ? (1:k) : (:), ndims(x))...)
    broadcast(tv, CartesianIndices(ntuple(d -> d==dims ? Base.OneTo(1) : axes(x,d), ndims(x)))) do (_,i), J
        CartesianIndex(ntuple(d -> d==dims ? i : J[d], ndims(x)))
    end
end

# piracy to make e.g. this work:  sort(tuple.(cu(rand(10)), 1:10), by=first)
@inline Base.zero(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> zero(T.parameters[i]), N)
@inline Base.one(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> one(T.parameters[i]), N)
julia> x = 100randn(5,6)
5×6 Matrix{Float64}:
 -123.513   -106.66      25.3997   41.3127   105.062    -20.8767
  161.838     49.7304   -44.2289  -44.0282  -227.478    -62.3863
  -99.9103    90.3985  -203.78     22.0575   -14.5563  -242.797
   50.9009   120.479   -213.53     53.8734   -33.2207  -118.205
   50.3431    54.3659    49.8969  204.863   -103.487    -41.3977

julia> topk(cu(x), 2)
2×6 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 161.838   120.479   49.8969  204.863   105.062   -20.8767
  50.9009   90.3985  25.3997   53.8734  -14.5563  -41.3977

My CPU version above is not correct, since it uses indices along one slice as linear indices in the whole array. It would need to also make CartesianIndices everywhere. After which it's not just one line... is there a tidier way?

function topkperm(x::AbstractArray, k::Integer; dims::Integer=1, rev=true, kw...)
    out = similar(CartesianIndices(ntuple(d -> d==dims ? Base.OneTo(k) : axes(x,d), ndims(x))))
    iters = ntuple(d -> d==dims ? (Colon(),) : axes(x,d), ndims(x))
    for J in Iterators.product(iters...)
        p = partialsortperm(view(x, J...), 1:k; rev=rev, kw...)
        for i in 1:k
            I = ntuple(d -> d==dims ? i : J[d], ndims(x))
            PI = ntuple(d -> d==dims ? p[i] : J[d], ndims(x))
            out[I...] = CartesianIndex(PI)
        end
    end
    out
end

@CarloLucibello
Copy link
Member Author

@mcabbott thanks for all these suggestions, feel free to push on this branch.

I like the topk - topkperm decoupling, so that it mimics base's sort functions.

Representation of the (partial) permutation as an array of CartesianIndex it's redundant, since it would be enough to
return an array of integers

topkout[i1, ..,iN, k] = x[i1, ...,iN, perm[i1, ..., iN, k]]

We could return a custom type for perm storing the integer permutation and dims. getindex could be overloaded so that we can do x[perm]. But maybe this adds complication that turns out to be not so useful in the end, so I would be totally fine with returning cartesian indexes and revisiting later if needed.

@mcabbott
Copy link
Member

Yes I don't like the redundancy of returning CartesianIndices, but I do like the simplicity of getindex. Maybe returning an array of linear indices instead would be nicer?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

add topk / maxk

3 participants