Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,40 @@ LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)

## permutedims

function genperm(I::CartesianIndex{N}, perm::NTuple{N}) where N
CartesianIndex(ntuple(d-> (@inbounds return I[perm[d]]), Val(N)))
function genperm(I::NTuple{N}, perm::NTuple{N}) where N
ntuple(d-> (@inbounds return I[perm[d]]), Val(N))
end

function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) where N
perm isa Tuple || (perm = Tuple(perm))
gpu_call(dest, src, perm; name="permutedims!") do ctx, dest, src, perm
I = @cartesianidx src
@inbounds dest[genperm(I, perm)] = src[I]
i = @linearidx src
I = l2c(size(src), i)
@inbounds dest[c2l(size(dest), genperm(I, perm))] = src[i]
return
end
return dest
end

using Base.Cartesian
@generated function c2l(size::NTuple{N, Int}, c::NTuple{N,Int}) where N
quote
res = c[1]
stride = size[1]
@nexprs $(N-1) i->begin
res += (c[i+1]-1) * stride
stride *= size[i+1]
end
return res
end
end

@generated function l2c(size::NTuple{N, Int}, l::Int) where N
quote
l -= 1
@nexprs $(N-1) i->begin
@inbounds l, s_i = divrem(l, size[i])
end
$(Expr(:tuple, [:($(Symbol(:s_, i))+1) for i=1:N-1]..., :(l+1)))
end
end
23 changes: 22 additions & 1 deletion test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,27 @@
@test compare(x -> permutedims(x, [2,1,4,3]), AT, randn(ComplexF64,3,4,5,1))
end

@testset "c2l" begin
for i=1:100
shape = (4,rand(1:5),rand(1:7),5,19)
target = ([rand(1:s) for s in shape]...,)
@test c2l(shape, target) == LinearIndices(shape)[target...]
end
for i=1:100
shape = (4,rand(1:5),rand(1:12),15,19)
ci = CartesianIndices(shape)
i = rand(1:prod(shape))
@test l2c(shape, i) == ci[i].I
end
end

@testset "permutedims" begin
a = randn(rand(1:3, 18)...)
A = CuArray(a)
p = randperm(18)
@test Array(permutedims(A, p)) ≈ permutedims(a, p)
end

@testset "issymmetric/ishermitian" begin
n = 128
areal = randn(n,n)/2
Expand Down Expand Up @@ -114,4 +135,4 @@
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
end
end
end