Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,8 +687,8 @@ function LinearAlgebra.rotate!(x::AbstractGPUArray, y::AbstractGPUArray, c::Numb
i = @index(Global, Linear)
@inbounds xi = x[i]
@inbounds yi = y[i]
@inbounds x[i] = c * xi + s * yi
@inbounds y[i] = -conj(s) * xi + c * yi
@inbounds x[i] = s*yi + c *xi
@inbounds y[i] = c*yi - conj(s)*xi
end
rotate_kernel!(get_backend(x))(x, y, c, s; ndrange = size(x))
return x, y
Expand Down
14 changes: 13 additions & 1 deletion test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,24 @@ function compare(f, AT::Type{<:AbstractGPUArray}, xs...; kwargs...)
end

function compare(f, AT::Type{<:Array}, xs...; kwargs...)
# no need to actually run this tests: we have nothing to compoare against,
# no need to actually run this tests: we have nothing to compare against,
# and we'll run it on a CPU array anyhow when comparing to a GPU array.
#
# this method exists so that we can at least run the test suite with Array,
# and make sure we cover other tests (that don't call `compare`) too.
return true
end

has_NaNs(a::AbstractArray) = isfloattype(eltype(a)) && any(isnan, collect(a))
has_NaNs(as::NTuple) = any(a -> has_NaNs(a), as)

out_has_NaNs(f, AT::Type{<:Array}, xs...) = false # we do not test stdlibs/LinAlg for NaNs (maybe they should?)
function out_has_NaNs(f, AT::Type{<:AbstractGPUArray}, xs...)
arg_in = map(x -> isa(x, Base.RefValue) ? x[] : adapt(AT, x), xs)
arg_out = f(arg_in...)
return has_NaNs(arg_out)
end

# element types that are supported by the array type
supported_eltypes(AT, test) = supported_eltypes(AT)
supported_eltypes(AT) = supported_eltypes()
Expand All @@ -67,6 +77,8 @@ isrealtype(T) = T <: Real
iscomplextype(T) = T <: Complex
isrealfloattype(T) = T <: AbstractFloat
isfloattype(T) = T <: AbstractFloat || T <: Complex{<:AbstractFloat}
NaN_T(T::Type{<:AbstractFloat}) = T(NaN)
NaN_T(T::Type{<:Complex{<:AbstractFloat}}) = T(NaN, NaN)

# list of tests
const tests = Dict()
Expand Down
27 changes: 27 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,19 @@
@testset "lmul! and rmul!" for (a,b) in [((3,4),(4,3)), ((3,), (1,3)), ((1,3), (3))], T in eltypes
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
if isfloattype(T)
@test compare(rmul!, AT, fill(NaN_T(T), a), Ref(false))
@test compare(lmul!, AT, Ref(false), fill(NaN_T(T), b))
end
end

@testset "axp{b}y" for T in eltypes
@test compare(axpby!, AT, Ref(rand(T)), rand(T,5), Ref(rand(T)), rand(T,5))
@test compare(axpy!, AT, Ref(rand(T)), rand(T,5), rand(T,5))
if isfloattype(T)
@test compare(axpby!, AT, Ref(false), fill(NaN_T(T), 5), Ref(false), fill(NaN_T(T), 5))
@test compare(axpy!, AT, Ref(false), fill(NaN_T(T), 5), rand(T, 5))
end
end

@testset "dot" for T in eltypes
Expand All @@ -295,10 +303,18 @@

@testset "rotate!" for T in eltypes
@test compare(rotate!, AT, rand(T,5), rand(T,5), Ref(rand(real(T))), Ref(rand(T)))
if isfloattype(T)
# skip compare until https://github.com/JuliaLang/LinearAlgebra.jl/pull/1323 is released and only check correct strong zero behaviour of AbstractGPUArray
# @test compare(rotate!, AT, fill(NaN_T(T), 5), fill(NaN_T(T), 5), Ref(false), Ref(false))
@test !out_has_NaNs(rotate!, AT, fill(NaN_T(T), 5), fill(NaN_T(T), 5), Ref(false), Ref(false))
end
end

@testset "reflect!" for T in eltypes
@test compare(reflect!, AT, rand(T,5), rand(T,5), Ref(rand(real(T))), Ref(rand(T)))
if isfloattype(T)
@test compare(reflect!, AT, fill(NaN_T(T), 5), fill(NaN_T(T), 5), Ref(false), Ref(false))
end
end

@testset "iszero and isone" for T in eltypes
Expand Down Expand Up @@ -330,6 +346,13 @@ end
@test compare(*, AT, f(A), x)
@test compare(mul!, AT, y, f(A), x)
@test compare(mul!, AT, y, f(A), x, Ref(T(4)), Ref(T(5)))
if isfloattype(T)
y_NaN, A_NaN, x_NaN = fill(NaN_T(T), 4), fill(NaN_T(T), 4, 4), fill(NaN_T(T), 4)
if !(T==Float16) && !(T == ComplexF16) # skip Float16/ComplexF16 until https://github.com/JuliaLang/LinearAlgebra.jl/issues/1399 is fixed and only check correct strong zero behaviour of AbstractGPUArray
@test compare(mul!, AT, y_NaN, f(A_NaN), x_NaN, Ref(false), Ref(false))
end
@test !out_has_NaNs(mul!, AT, y_NaN, f(A_NaN), x_NaN, Ref(false), Ref(false))
end
@test typeof(AT(rand(T, 3, 3)) * AT(rand(T, 3))) <: AbstractVector

if f !== identity
Expand All @@ -348,6 +371,10 @@ end
@test compare(*, AT, f(A), g(B))
@test compare(mul!, AT, C, f(A), g(B))
@test compare(mul!, AT, C, f(A), g(B), Ref(T(4)), Ref(T(5)))
if isfloattype(T)
A_NaN, B_NaN, C_NaN = fill(NaN_T(T), 4, 4), fill(NaN_T(T), 4, 4), fill(NaN_T(T), 4, 4)
@test compare(mul!, AT, C_NaN, f(A_NaN), g(B_NaN), Ref(false), Ref(false))
end
@test typeof(AT(rand(T, 3, 3)) * AT(rand(T, 3, 3))) <: AbstractMatrix
end
end
Expand Down
Loading