Skip to content

Commit 292848f

Browse files
authored
Merge pull request #338 from JuliaGPU/tb/permutedims
Specialize permutedims kernel for the permutation.
2 parents dcc6cba + 1d533fa commit 292848f

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ steps:
5 5
- label: "CUDA.jl"
6 6
plugins:
7 7
- JuliaCI/julia#v0.4:
8-
version: 1.5
8+
version: nightly
9 9
- JuliaCI/julia-coverage#v0.2:
10 10
codecov: true
11 11
command: |
command: |
@@ -25,7 +25,7 @@ steps:
25 25
- label: "oneAPI.jl"
26 26
plugins:
27 27
- JuliaCI/julia#v0.4:
28-
version: 1.5
28+
version: nightly
29 29
- JuliaCI/julia-coverage#v0.2:
30 30
codecov: true
31 31
command: |
command: |

src/host/linalg.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,16 +193,21 @@ LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
193 193

194 194
## permutedims
195 195

196-
function genperm(I::CartesianIndex{N}, perm::NTuple{N}) where N
197-
CartesianIndex(ntuple(d-> (@inbounds return I[perm[d]]), Val(N)))
198-
end
199-
200-
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) where N
201-
perm isa Tuple || (perm = Tuple(perm))
202-
gpu_call(dest, src, perm; name="permutedims!") do ctx, dest, src, perm
196+
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray,
197+
perm::NTuple)
198+
Base.checkdims_perm(dest, src, perm)
199+
function permutedims_kernel(ctx, dest, src, ::Val{perm}) where {perm}
203 200
I = @cartesianidx src
204-
@inbounds dest[genperm(I, perm)] = src[I]
201+
@inbounds begin
202+
J = CartesianIndex(map(i->I[i], perm))
203+
dest[J] = src[I]
204+
end
205 205
return
206 206
end
207+
gpu_call(permutedims_kernel, dest, src, Val(perm))
207 208
return dest
208 209
end
210+
211+
# TODO: implementation without the memory copy
212+
LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) =
213+
permutedims!(dest, src, Tuple(perm))

0 commit comments

Comments
 (0)