|
1 | 1 | using MatrixAlgebraKit |
2 | 2 | using MatrixAlgebraKit: diagview |
3 | | -using LinearAlgebra: Diagonal, isposdef |
| 3 | +using LinearAlgebra: Diagonal, isposdef, opnorm |
4 | 4 | using Test |
5 | 5 | using TestExtras |
6 | 6 | using StableRNGs |
|
96 | 96 | p = min(m, n) - k - 1 |
97 | 97 | algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi(), CUSOLVER_Randomized(; k=k, p=p, niters=100),) |
98 | 98 | @testset "algorithm $alg" for alg in algs |
99 | | - #n > m && alg isa CUSOLVER_Jacobi && continue # not supported |
| 99 | + n > m && alg isa CUSOLVER_QRIteration && continue # not supported |
100 | 100 | hA = randn(rng, T, m, n) |
101 | 101 | S₀ = svd_vals(hA) |
102 | 102 | A = CuArray(hA) |
|
105 | 105 |
|
106 | 106 | U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) |
107 | 107 | @test length(S1.diag) == r |
108 | | - @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] |
| 108 | + @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] |
109 | 109 |
|
110 | 110 | if !(alg isa CUSOLVER_Randomized) |
111 | 111 | s = 1 + sqrt(eps(real(T))) |
|
114 | 114 | U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(s * S₀[r + 1])) |
115 | 115 | @test length(S2.diag) == r |
116 | 116 | @test U1 ≈ U2 |
117 | | - @test S1 ≈ S2 |
| 117 | + @test parent(S1) ≈ parent(S2) |
118 | 118 | @test V1ᴴ ≈ V2ᴴ |
119 | 119 | end |
120 | | - |
121 | | - #=A = CuArray(randn(rng, T, m, n)) |
122 | | - Uref, Sref, Vᴴref = svd_full(A, CUSOLVER_SVDPolar()) |
123 | | - U, S, Vᴴ = svd_full(A; alg) |
124 | | - @test U isa CuMatrix{T} && size(U) == (m, m) |
125 | | - @test S isa CuMatrix{real(T)} && size(S) == (m, n) |
126 | | - @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (n, n) |
127 | | - for col in 1:k |
128 | | - @test view(collect(U), :, col) ≈ view(collect(Uref), :, col) |
129 | | - @test view(collect(Vᴴ), col, :) ≈ view(collect(Vᴴref), col, :) |
130 | | - end |
131 | | - @test all(isposdef, view(diagview(S), 1:k)) |
132 | | - @test view(CuArray(diagview(S)), 1:k) ≈ view(CuArray(diagview(Sref)), 1:k) |
133 | | -
|
134 | | - Ac = similar(A) |
135 | | - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) |
136 | | - @test U2 === U |
137 | | - @test S2 === S |
138 | | - @test V2ᴴ === Vᴴ |
139 | | - for col in 1:k |
140 | | - @test view(collect(U), :, col) ≈ view(collect(Uref), :, col) |
141 | | - @test view(collect(Vᴴ), col, :) ≈ view(collect(Vᴴref), col, :) |
142 | | - end |
143 | | - @test all(isposdef, view(diagview(S), 1:k)) |
144 | | - @test view(CuArray(diagview(S2)), 1:k) ≈ view(CuArray(diagview(Sref)), 1:k) |
145 | | -
|
146 | | - Sc = similar(A, real(T), k) |
147 | | - Sc2 = svd_vals!(copy!(Ac, A), Sc, alg) |
148 | | - @test Sc === Sc2 |
149 | | - @test view(Sc, 1:k) ≈ view(CuArray(diagview(Sref)), 1:k) |
150 | | - @test view(CuArray(diagview(S)), 1:k) ≈ Sc |
151 | | - # CuArray is necessary because norm of CuArray view with non-unit step is broken |
152 | | - =# |
153 | 120 | end |
154 | 121 | end |
155 | 122 | end |
0 commit comments