Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore",
"ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]
5 changes: 5 additions & 0 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,9 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix
return C
end

# TODO: intersect on GPU arrays is not working
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B)
MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B))
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B))

end
5 changes: 5 additions & 0 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,9 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T
return C
end

# TODO: intersect on GPU arrays is not working
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B)
MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B))
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B))

end
17 changes: 9 additions & 8 deletions test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using TestExtras
using StableRNGs
using LinearAlgebra: Diagonal
using CUDA, AMDGPU
using CUDA.CUSOLVER # pull in opnorm binding

BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
GenericFloats = (BigFloat, Complex{BigFloat})
Expand All @@ -17,28 +18,28 @@ for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63)
TestSuite.seed_rng!(123)
if T ∈ BLASFloats
if CUDA.functional()
TestSuite.test_svd(CuMatrix{T}, (m, n); test_trunc = false)
TestSuite.test_svd(CuMatrix{T}, (m, n))
CUDA_SVD_ALGS = (
CUSOLVER_QRIteration(),
CUSOLVER_SVDPolar(),
CUSOLVER_Jacobi(),
)
TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS; test_trunc = false)
TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS)
if n == m
TestSuite.test_svd(Diagonal{T, CuVector{T}}, m; test_trunc = false)
TestSuite.test_svd_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
TestSuite.test_svd(Diagonal{T, CuVector{T}}, m)
TestSuite.test_svd_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),))
end
end
if AMDGPU.functional()
TestSuite.test_svd(ROCMatrix{T}, (m, n); test_trunc = false)
TestSuite.test_svd(ROCMatrix{T}, (m, n))
AMD_SVD_ALGS = (
ROCSOLVER_QRIteration(),
ROCSOLVER_Jacobi(),
)
TestSuite.test_svd_algs(ROCMatrix{T}, (m, n), AMD_SVD_ALGS; test_trunc = false)
TestSuite.test_svd_algs(ROCMatrix{T}, (m, n), AMD_SVD_ALGS)
if n == m
TestSuite.test_svd(Diagonal{T, ROCVector{T}}, m; test_trunc = false)
TestSuite.test_svd_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
TestSuite.test_svd(Diagonal{T, ROCVector{T}}, m)
TestSuite.test_svd_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),))
end
end
end
Expand Down
5 changes: 5 additions & 0 deletions test/testsuite/TestSuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ isrightcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isrightcomplete(collect(V), co
isrightcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isrightcomplete(collect(V), collect(N))

instantiate_unitary(T, A, sz) = qr_compact(randn!(similar(A, eltype(T), sz, sz)))[1]
# AMDGPU can't generate ComplexF32 random numbers
function instantiate_unitary(T, A::ROCMatrix{<:Complex}, sz)
sqA = randn!(similar(A, real(eltype(T)), sz, sz)) .+ im .* randn!(similar(A, real(eltype(T)), sz, sz))
return qr_compact(sqA)[1]
end
instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A))))

include("qr.jl")
Expand Down
28 changes: 15 additions & 13 deletions test/testsuite/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@ using TestExtras
using GenericLinearAlgebra
using LinearAlgebra: opnorm

function test_svd(T::Type, sz; test_trunc = true, kwargs...)
function test_svd(T::Type, sz; kwargs...)
summary_str = testargs_summary(T, sz)
return @testset "svd $summary_str" begin
test_svd_compact(T, sz; kwargs...)
test_svd_full(T, sz; kwargs...)
test_trunc && test_svd_trunc(T, sz; kwargs...)
test_svd_trunc(T, sz; kwargs...)
end
end

function test_svd_algs(T::Type, sz, algs; test_trunc = true, kwargs...)
function test_svd_algs(T::Type, sz, algs; kwargs...)
summary_str = testargs_summary(T, sz)
return @testset "svd algorithms $summary_str" begin
test_svd_compact_algs(T, sz, algs; kwargs...)
test_svd_full_algs(T, sz, algs; kwargs...)
test_trunc && test_svd_trunc_algs(T, sz, algs; kwargs...)
test_svd_trunc_algs(T, sz, algs; kwargs...)
end
end

Expand Down Expand Up @@ -160,14 +160,15 @@ function test_svd_trunc(
Ac = deepcopy(A)
m, n = size(A)
minmn = min(m, n)
S₀ = svd_vals(A)
S₀ = collect(svd_vals(A))
r = minmn - 2

if m > 0 && n > 0
U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc(A; trunc = truncrank(r))
@test length(diagview(S1)) == r
@test diagview(S1) ≈ S₀[1:r]
@test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]
@test collect(diagview(S1)) ≈ S₀[1:r]
AUSV_vals = svd_vals(A - U1 * S1 * V1ᴴ) # bypass broken svdvals on AMDGPU
@test mapreduce(sv -> opnorm(sv, 2), max, AUSV_vals) ≈ S₀[r + 1]
# Test truncation error
@test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol

Expand Down Expand Up @@ -241,14 +242,15 @@ function test_svd_trunc_algs(
Ac = deepcopy(A)
m, n = size(A)
minmn = min(m, n)
S₀ = svd_vals(A)
S₀ = collect(svd_vals(A))
r = minmn - 2

if m > 0 && n > 0
U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc(A; trunc = truncrank(r), alg)
@test length(diagview(S1)) == r
@test diagview(S1) ≈ S₀[1:r]
@test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]
@test collect(diagview(S1)) ≈ S₀[1:r]
AUSV_vals = svd_vals(A - U1 * S1 * V1ᴴ) # bypass broken svdvals on AMDGPU
@test mapreduce(sv -> opnorm(sv, 2), max, AUSV_vals) ≈ S₀[r + 1]
# Test truncation error
@test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol

Expand Down Expand Up @@ -285,11 +287,11 @@ function test_svd_trunc_algs(
)
U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.2, 1), alg)
@test length(diagview(S1)) == 1
@test diagview(S1) ≈ diagview(S)[1:1]
@test collect(diagview(S1))collect(diagview(S)[1:1])

U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; trunc = trunc_fun(0.2, 3), alg)
@test length(diagview(S2)) == 2
@test diagview(S2) ≈ diagview(S)[1:2]
@test collect(diagview(S2))collect(diagview(S)[1:2])
end
end
@testset "specify truncation algorithm" begin
Expand All @@ -303,7 +305,7 @@ function test_svd_trunc_algs(
A = U * S * Vᴴ
truncalg = TruncatedAlgorithm(alg, trunctol(; atol = 0.2))
U2, S2, V2ᴴ, ϵ2 = @testinferred svd_trunc(A; alg = truncalg)
@test diagview(S2) ≈ diagview(S)[1:2]
@test collect(diagview(S2))collect(diagview(S)[1:2])
@test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol
@test_throws ArgumentError svd_trunc(A; alg = truncalg, trunc = (; maxrank = 2))
@test_throws ArgumentError svd_trunc_no_error(A; alg = truncalg, trunc = (; maxrank = 2))
Expand Down