Skip to content

Commit d747a40

Browse files
committed
Specialize permutedims kernel for the permutation.
1 parent dcc6cba commit d747a40

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

src/host/linalg.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,16 +193,21 @@ LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
193193

194194
## permutedims
195195

196-
function genperm(I::CartesianIndex{N}, perm::NTuple{N}) where N
197-
CartesianIndex(ntuple(d-> (@inbounds return I[perm[d]]), Val(N)))
198-
end
199-
200-
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) where N
201-
perm isa Tuple || (perm = Tuple(perm))
202-
gpu_call(dest, src, perm; name="permutedims!") do ctx, dest, src, perm
196+
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray,
197+
perm::NTuple)
198+
Base.checkdims_perm(dest, src, perm)
199+
function permutedims_kernel(ctx, dest, src, ::Val{perm}) where {perm}
203200
I = @cartesianidx src
204-
@inbounds dest[genperm(I, perm)] = src[I]
201+
@inbounds begin
202+
J = CartesianIndex(map(i->I[i], perm))
203+
dest[J] = src[I]
204+
end
205205
return
206206
end
207+
gpu_call(permutedims_kernel, dest, src, Val(perm))
207208
return dest
208209
end
210+
211+
# TODO: implementation without the memory copy
212+
LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) =
213+
permutedims!(dest, src, Tuple(perm))

0 commit comments

Comments
 (0)