Skip to content

Commit 8eb42a8

Browse files
committed
Fix transpose/adjoint.
1 parent afb4821 commit 8eb42a8

File tree

2 files changed

+39
-19
lines changed

2 files changed

+39
-19
lines changed

src/host/linalg.jl

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
# integration with LinearAlgebra stdlib
22

3-
function LinearAlgebra.transpose!(At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
3+
## transpose and adjoint
4+
5+
# Shared GPU kernel behind `transpose!` and `adjoint!` in this file.
#
# For the Cartesian index `idx` that `@cartesianidx A ctx` yields inside the kernel,
# `f(A[idx[1], idx[2]])` is stored at `At[idx[2], idx[1]]`; the callers in this file
# pass `transpose` or `adjoint` as `f`. Returns `At`.
#
# Throws `DimensionMismatch` unless `size(At) == (size(A, 2), size(A, 1))`. The kernel
# writes with `@inbounds`, so a destination of any other size would be indexed out of
# bounds instead of raising an error.
function transpose_f!(f, At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
    if size(At) != (size(A, 2), size(A, 1))
        throw(DimensionMismatch(
            "destination has size $(size(At)), expected $((size(A, 2), size(A, 1)))"))
    end
    gpu_call(At, A) do ctx, At, A
        idx = @cartesianidx A ctx
        @inbounds At[idx[2], idx[1]] = f(A[idx[1], idx[2]])
        return
    end
    return At
end
1113

12-
function genperm(I::CartesianIndex{N}, perm::NTuple{N}) where N
13-
CartesianIndex(ntuple(d-> (@inbounds return I[perm[d]]), Val(N)))
14+
# In-place transpose for GPU arrays: forwards to `transpose_f!` with `transpose` as the
# per-element function.
function LinearAlgebra.transpose!(At::AbstractGPUArray, A::AbstractGPUArray)
    return transpose_f!(transpose, At, A)
end
15+
# In-place adjoint for GPU arrays: forwards to `transpose_f!` with `adjoint` as the
# per-element function.
function LinearAlgebra.adjoint!(At::AbstractGPUArray, A::AbstractGPUArray)
    return transpose_f!(adjoint, At, A)
end
16+
17+
# Materialise a lazy `Adjoint` wrapper around a GPU array into the GPU array `A` by
# calling `adjoint!` on the wrapped parent array.
function Base.copyto!(A::AbstractGPUArray, B::Adjoint{T, <: AbstractGPUArray}) where T
    wrapped = parent(B)
    return adjoint!(A, wrapped)
end
1520

16-
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) where N
17-
perm isa Tuple || (perm = Tuple(perm))
18-
gpu_call(dest, src, perm) do ctx, dest, src, perm
19-
I = @cartesianidx src ctx
20-
@inbounds dest[genperm(I, perm)] = src[I]
21-
return
22-
end
23-
return dest
21+
# Materialise a lazy `Transpose` wrapper around a GPU array into the GPU array `A` by
# calling `transpose!` on the wrapped parent array.
function Base.copyto!(A::AbstractGPUArray, B::Transpose{T, <: AbstractGPUArray}) where T
    wrapped = parent(B)
    return transpose!(A, wrapped)
end
2524

2625
function Base.copyto!(A::AbstractArray, B::Adjoint{<:Any, <:AbstractGPUArray})
@@ -29,17 +28,17 @@ end
2928
# Copy a lazily transposed GPU array into `A`: the wrapped GPU data is converted with
# `Array(...)`, wrapped in `Transpose` again, and handed to `copyto!`.
function Base.copyto!(A::AbstractArray, B::Transpose{<:Any, <:AbstractGPUArray})
    host = Array(parent(B))
    return copyto!(A, Transpose(host))
end
31+
32+
33+
## triangular
34+
3235
# Copy an `UpperTriangular` view of a GPU array into `A`: the wrapped GPU data is
# converted with `Array(...)`, wrapped in `UpperTriangular` again, and handed to `copyto!`.
function Base.copyto!(A::AbstractArray, B::UpperTriangular{<:Any, <:AbstractGPUArray})
    host = Array(parent(B))
    return copyto!(A, UpperTriangular(host))
end
3538
# Copy a `LowerTriangular` view of a GPU array into `A`: the wrapped GPU data is
# converted with `Array(...)`, wrapped in `LowerTriangular` again, and handed to `copyto!`.
function Base.copyto!(A::AbstractArray, B::LowerTriangular{<:Any, <:AbstractGPUArray})
    host = Array(parent(B))
    return copyto!(A, LowerTriangular(host))
end
3841

39-
function Base.copyto!(A::AbstractGPUArray, B::Adjoint{T, <: AbstractGPUArray}) where T
40-
transpose!(A, B.parent)
41-
end
42-
4342
function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
4443
gpu_call(A, d) do ctx, _A, _d
4544
I = @cartesianidx _A
@@ -65,8 +64,7 @@ function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
6564
end
6665

6766

68-
69-
# matrix multiplication
67+
## matrix multiplication
7068

7169
function generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) where {T,S,R}
7270
if size(A,2) != size(B,1)
@@ -129,3 +127,20 @@ function generic_lmul!(s::Number, X::AbstractGPUArray)
129127
end
130128

131129
# `lmul!` with a scalar on the left and a GPU array on the right forwards to
# `generic_lmul!` with the same arguments.
function LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray)
    return generic_lmul!(a, B)
end
130+
131+
132+
## permutedims
133+
134+
# Build the permuted index: component `d` of the result is `I[perm[d]]`.
# `Val(N)` gives `ntuple` the length as a type parameter.
function genperm(I::CartesianIndex{N}, perm::NTuple{N}) where N
    permuted = ntuple(Val(N)) do d
        @inbounds I[perm[d]]
    end
    return CartesianIndex(permuted)
end
137+
138+
# Permute the dimensions of `src` into `dest` on the GPU and return `dest`.
#
# `perm` may be a tuple or any collection accepted by `Tuple(perm)`. Inside the kernel
# the element at index `I` of `src` is stored at `dest[genperm(I, perm)]`.
#
# Throws `ArgumentError` when `perm` is not a permutation of `1:ndims(src)`, and
# `DimensionMismatch` when `dest` does not have the permuted size of `src`. These checks
# matter because the kernel writes with `@inbounds`.
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm)
    # `genperm` takes an `NTuple`, so convert vectors and other collections up front.
    permtuple = perm isa Tuple ? perm : Tuple(perm)

    rank = ndims(src)
    ndims(dest) == rank ||
        throw(DimensionMismatch("destination has $(ndims(dest)) dimensions, source has $rank"))
    length(permtuple) == rank ||
        throw(ArgumentError(
            "expected permutation of size $rank, but length(perm)=$(length(permtuple))"))
    isperm(permtuple) || throw(ArgumentError("input is not a permutation"))
    for d in 1:rank
        size(dest, d) == size(src, permtuple[d]) ||
            throw(DimensionMismatch("destination tensor of incorrect size"))
    end

    gpu_call(dest, src, permtuple) do ctx, dest, src, perm
        I = @cartesianidx src ctx
        @inbounds dest[genperm(I, perm)] = src[I]
        return
    end
    return dest
end

test/testsuite/linalg.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
function test_linalg(AT)
22
@testset "linear algebra" begin
3-
@testset "transpose" begin
3+
@testset "adjoint and transpose" begin
    # Real inputs: adjoint and transpose give the same result here, so these cases
    # only check that elements land in the right place.
    @test compare(adjoint, AT, rand(Float32, 32, 32))
    @test compare(adjoint!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
    @test compare(transpose, AT, rand(Float32, 32, 32))
    @test compare(transpose!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
    @test compare((x,y)->copyto!(x, adjoint(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
    @test compare((x,y)->copyto!(x, transpose(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
    @test compare(transpose!, AT, Array{Float32}(undef, 32, 32), rand(Float32, 32, 32))
    @test compare(transpose!, AT, Array{Float32}(undef, 128, 32), rand(Float32, 32, 128))

    # Complex inputs: adjoint conjugates and transpose does not, so these cases fail
    # if one operation is implemented in terms of the other.
    @test compare(adjoint, AT, rand(ComplexF32, 32, 32))
    @test compare(adjoint!, AT, rand(ComplexF32, 32, 32), rand(ComplexF32, 32, 32))
    @test compare(transpose, AT, rand(ComplexF32, 32, 32))
    @test compare(transpose!, AT, rand(ComplexF32, 32, 32), rand(ComplexF32, 32, 32))
    @test compare((x,y)->copyto!(x, adjoint(y)), AT,
                  rand(ComplexF32, 32, 32), rand(ComplexF32, 32, 32))
    @test compare((x,y)->copyto!(x, transpose(y)), AT,
                  rand(ComplexF32, 32, 32), rand(ComplexF32, 32, 32))
end
@@ -19,6 +23,7 @@ function test_linalg(AT)
1923
copyto!(ga, LowerTriangular(gb))
2024
@test ga == Array(collect(LowerTriangular(gb)))
2125
end
26+
2227
@testset "permutedims" begin
2328
@test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3))
2429
@test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6))

0 commit comments

Comments
 (0)