@@ -193,16 +193,21 @@ LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
193
193
194
194
## permutedims
195
195
196
- function genperm (I:: CartesianIndex{N} , perm:: NTuple{N} ) where N
197
- CartesianIndex (ntuple (d-> (@inbounds return I[perm[d]]), Val (N)))
198
- end
199
-
200
- function LinearAlgebra. permutedims! (dest:: AbstractGPUArray , src:: AbstractGPUArray , perm) where N
201
- perm isa Tuple || (perm = Tuple (perm))
202
- gpu_call (dest, src, perm; name= " permutedims!" ) do ctx, dest, src, perm
196
# Permute the dimensions of `src` into `dest` according to the tuple `perm`, returning
# `dest`. Arguments are validated up front with `Base.checkdims_perm`.
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray,
                                    perm::NTuple)
    Base.checkdims_perm(dest, src, perm)

    # The permutation is passed as a `Val` type parameter, so inside the kernel `perm`
    # is a static parameter rather than a runtime argument.
    function permutedims_kernel(ctx, dest, src, ::Val{perm}) where {perm}
        src_idx = @cartesianidx src
        @inbounds begin
            # Component `d` of the destination index is component `perm[d]` of the
            # source index.
            dest_idx = CartesianIndex(map(d -> src_idx[d], perm))
            dest[dest_idx] = src[src_idx]
        end
        return
    end

    gpu_call(permutedims_kernel, dest, src, Val(perm))
    return dest
end
210
+
211
# Fallback for permutations that are not homogeneous tuples (vectors, ranges, mixed-type
# tuples, ...): normalize to a tuple of `Int`s and forward to the `perm::NTuple` method.
# Converting the elements matters: `Tuple(perm)` returns a mixed-type tuple such as
# `(1, Int32(2))` unchanged, and since that does not match `::NTuple` the call would
# dispatch straight back to this method and recurse until the stack overflows.
# TODO: implementation without the memory copy
LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) =
    permutedims!(dest, src, map(Int, Tuple(perm)))
0 commit comments