diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 2282528e..aa8806a3 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -14,6 +14,7 @@ function LinearAlgebra.transpose!(B::AbstractGPUMatrix, A::AbstractGPUVector) end function LinearAlgebra.adjoint!(B::AbstractGPUVector, A::AbstractGPUMatrix) axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("adjoint")) + isempty(A) && return B @kernel function adjoint_kernel!(B, A) idx = @index(Global, Linear) @inbounds B[idx] = adjoint(A[1, idx]) @@ -23,6 +24,7 @@ function LinearAlgebra.adjoint!(B::AbstractGPUVector, A::AbstractGPUMatrix) end function LinearAlgebra.adjoint!(B::AbstractGPUMatrix, A::AbstractGPUVector) axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("adjoint")) + isempty(A) && return B @kernel function adjoint_kernel!(B, A) idx = @index(Global, Linear) @inbounds B[1, idx] = adjoint(A[idx]) @@ -35,6 +37,8 @@ LinearAlgebra.transpose!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(transpos LinearAlgebra.adjoint!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(adjoint, B, A) function transpose_f!(f, B::AnyGPUMatrix{T}, A::AnyGPUMatrix{T}) where T axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) || throw(DimensionMismatch(string(f))) + # array with size zero dimension + isempty(A) && return B @kernel function transpose_kernel!(B, A) idx = @index(Global, Cartesian) @inbounds B[idx[2], idx[1]] = f(A[idx[1], idx[2]]) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 5770c85f..7ecd61a9 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -4,10 +4,14 @@ @test compare(adjoint!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32)) @test compare(adjoint!, AT, rand(Float32, 1, 32), rand(Float32, 32)) @test compare(adjoint!, AT, rand(Float32, 32), rand(Float32, 1, 32)) + @test compare(adjoint!, AT, rand(Float32, 32, 0), rand(Float32, 0, 32)) + @test compare(adjoint!, AT, rand(Float32, 0, 32), rand(Float32, 32, 0)) @test compare(transpose, AT, rand(Float32, 32, 32)) @test compare(transpose!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32)) @test compare(transpose!, AT, rand(Float32, 1, 32), rand(Float32, 32)) @test compare(transpose!, AT, rand(Float32, 32), rand(Float32, 1, 32)) + @test compare(transpose!, AT, rand(Float32, 32, 0), rand(Float32, 0, 32)) + @test compare(transpose!, AT, rand(Float32, 0, 32), rand(Float32, 32, 0)) @test compare((x,y)->copyto!(x, adjoint(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32)) @test compare((x,y)->copyto!(x, transpose(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32)) @test compare(transpose!, AT, Array{Float32}(undef, 32, 32), rand(Float32, 32, 32))