@@ -2,11 +2,10 @@ using MatrixAlgebraKit
22using Test
33using TestExtras
44using StableRNGs
5- using LinearAlgebra: LinearAlgebra, I, mul!, norm
6- using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
5+ using LinearAlgebra: LinearAlgebra, I, mul!, diagm, norm
76using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
87 initialize_output, AbstractAlgorithm
9- using AMDGPU
8+ using AMDGPU
109
1110# Used to test non-AbstractMatrix codepaths.
1211struct LinearMap{P<: AbstractMatrix }
5453@testset " left_orth and left_null for T = $T " for T in (Float32, Float64, ComplexF32, ComplexF64)
5554 rng = StableRNG(123 )
5655 m = 54
57- for n in (37 , m, 63 )
56+ @testset for n in (37 , m, 63 )
5857 minmn = min(m, n)
5958 A = ROCArray(randn(rng, T, m, n))
6059 V, C = @constinferred left_orth(A)
6463 @test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
6564 @test V * C ≈ A
6665 @test isisometric(V)
67- @test LinearAlgebra . norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
66+ @test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
6867 @test isisometric(N)
6968 hV = collect(V)
7069 hN = collect(N)
7776 if m > n
7877 nullity = 5
7978 V, C = @constinferred left_orth(A)
80- # doesn' t work because of truncation
81- # N = @constinferred left_null(A; trunc=(; maxnullity=nullity))
79+ AMDGPU.@allowscalar begin
80+ N = @constinferred left_null(A; trunc=(; maxnullity=nullity))
81+ end
8282 @test V isa ROCMatrix{T} && size(V) == (m, minmn)
8383 @test C isa ROCMatrix{T} && size(C) == (minmn, n)
84- # @test N isa ROCMatrix{T} && size(N) == (m, nullity)
84+ @test N isa ROCMatrix{T} && size(N) == (m, nullity)
8585 @test V * C ≈ A
8686 @test isisometric(V)
87- # @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
88- # @test isisometric(N)
87+ @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit. defaulttol(T)
88+ @test isisometric(N)
8989 end
9090
9191 for alg_qr in ((; positive= true ), (; positive= false ), ROCSOLVER_HouseholderQR())
9696 @test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
9797 @test V * C ≈ A
9898 @test isisometric(V)
99- @test LinearAlgebra . norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
99+ @test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
100100 @test isisometric(N)
101101 hV = collect(V)
102102 hN = collect(N)
@@ -118,33 +118,39 @@ end
118118 @test hV2 * hV2' + hN2 * hN2' ≈ I
119119
120120 atol = eps(real(T))
121- # V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol))
121+ V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc= (; atol= atol))
122122 N2 = @constinferred left_null!(copy!(Ac, A), N; trunc= (; atol= atol))
123- # @test V2 !== V
124- # @test C2 !== C
123+ @test V2 != = V
124+ @test C2 != = C
125125 @test N2 != = C
126- # @test V2 * C2 ≈ A
127- # @test isisometric(V2)
126+ @test V2 * C2 ≈ A
127+ @test isisometric(V2)
128128 @test LinearAlgebra. norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
129129 @test isisometric(N2)
130- #@test V2 * V2' + N2 * N2' ≈ I
130+ hV2 = collect(V2)
131+ hN2 = collect(N2)
132+ @test hV2 * hV2' + hN2 * hN2' ≈ I
131133
132134 rtol = eps(real(T))
133- for (trunc_orth, trunc_null) in (((; rtol=rtol), (; rtol=rtol)),
134- (TruncationKeepAbove(0, rtol), TruncationKeepBelow(0, rtol)))
135- #V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth)
135+ for (trunc_orth, trunc_null) in (
136+ ((; rtol = rtol), (; rtol = rtol)),
137+ (trunctol(; rtol), trunctol(; rtol, keep_below = true)),
138+ )
139+ V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth)
136140 N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=trunc_null)
137- # @test V2 !== V
138- # @test C2 !== C
141+ @test V2 !== V
142+ @test C2 !== C
139143 @test N2 !== C
140- # @test V2 * C2 ≈ A
141- # @test isisometric(V2)
144+ @test V2 * C2 ≈ A
145+ @test isisometric(V2)
142146 @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit. defaulttol(T)
143147 @test isisometric(N2)
144- # @test V2 * V2' + N2 * N2' ≈ I
148+ hV2 = collect(V2)
149+ hN2 = collect(N2)
150+ @test hV2 * hV2' + hN2 * hN2' ≈ I
145151 end
146152
147- for kind in (:qr, :polar, :svd) # explicit kind kwarg
153+ @testset for kind in (:qr, :polar, :svd) # explicit kind kwarg
148154 m < n && kind == :polar && continue
149155 V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind= kind)
150156 @test V2 === V
@@ -163,31 +169,35 @@ end
163169
164170 # with kind and tol kwargs
165171 if kind == :svd
166- # V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
167- # trunc=(; atol=atol))
172+ V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
173+ trunc=(; atol=atol))
168174 N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind,
169175 trunc=(; atol=atol))
170- # @test V2 !== V
171- # @test C2 !== C
176+ @test V2 !== V
177+ @test C2 !== C
172178 @test N2 !== C
173- # @test V2 * C2 ≈ A
174- # @test V2' * V2 ≈ I
179+ @test V2 * C2 ≈ A
180+ @test V2' * V2 ≈ I
175181 @test LinearAlgebra. norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
176182 @test isisometric(N2)
177- #@test V2 * V2' + N2 * N2' ≈ I
183+ hV2 = collect(V2)
184+ hN2 = collect(N2)
185+ @test hV2 * hV2' + hN2 * hN2' ≈ I
178186
179- # V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
180- # trunc=(; rtol=rtol))
187+ V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
188+ trunc=(; rtol=rtol))
181189 N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind,
182190 trunc=(; rtol=rtol))
183- # @test V2 !== V
184- # @test C2 !== C
191+ @test V2 !== V
192+ @test C2 !== C
185193 @test N2 !== C
186- # @test V2 * C2 ≈ A
187- # @test isisometric(V2)
194+ @test V2 * C2 ≈ A
195+ @test isisometric(V2)
188196 @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit. defaulttol(T)
189197 @test isisometric(N2)
190- # @test V2 * V2' + N2 * N2' ≈ I
198+ hV2 = collect(V2)
199+ hN2 = collect(N2)
200+ @test hV2 * hV2' + hN2 * hN2' ≈ I
191201 else
192202 @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind= kind,
193203 trunc= (; atol= atol))
240250 hNᴴ2 = collect(Nᴴ2)
241251 @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I
242252
243- # TODO truncate currently broken due to searchsortedlast
244253 atol = eps(real(T))
245254 rtol = eps(real(T))
246- #= C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol))
255+ C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc= (; atol= atol))
247256 Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc= (; atol= atol))
248257 @test C2 != = C
249258 @test Vᴴ2 != = Vᴴ
252261 @test isisometric(Vᴴ2; side= :right)
253262 @test LinearAlgebra. norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit. defaulttol(T)
254263 @test isisometric(Nᴴ; side= :right)
255- @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
264+ hVᴴ2 = collect(Vᴴ2)
265+ hNᴴ2 = collect(Nᴴ2)
266+ @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I
256267
257268 C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc= (; rtol= rtol))
258269 Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc= (; rtol= rtol))
@@ -263,16 +274,16 @@ end
263274 @test isisometric(Vᴴ2; side= :right)
264275 @test LinearAlgebra. norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit. defaulttol(T)
265276 @test isisometric(Nᴴ2; side= :right)
266- @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
267- =#
277+ hVᴴ2 = collect(Vᴴ2)
278+ hNᴴ2 = collect(Nᴴ2)
279+ @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I
268280
269281 @testset " kind = $kind " for kind in (:lq, :polar, :svd)
270282 n < m && kind == :polar && continue
271283 C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind= kind)
272284 @test C2 === C
273285 @test Vᴴ2 === Vᴴ
274- A2 = C2 * Vᴴ2
275- @test A2 ≈ A
286+ @test C2 * Vᴴ2 ≈ A
276287 @test isisometric(Vᴴ2; side= :right)
277288 if kind != :polar
278289 Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind= kind)
285296 end
286297
287298 if kind == :svd
288- # doesn't work yet because of searchsortedfirst
289- #= C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
299+ C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind= kind,
290300 trunc= (; atol= atol))
291301 Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind= kind,
292302 trunc= (; atol= atol))
297307 @test isisometric(Vᴴ2; side= :right)
298308 @test LinearAlgebra. norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit. defaulttol(T)
299309 @test isisometric(Nᴴ2; side= :right)
300- @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
310+ hVᴴ2 = collect(Vᴴ2)
311+ hNᴴ2 = collect(Nᴴ2)
312+ @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I
301313
302314 C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind= kind,
303315 trunc= (; rtol= rtol))
310322 @test isisometric(Vᴴ2; side= :right)
311323 @test LinearAlgebra. norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit. defaulttol(T)
312324 @test isisometric(Nᴴ2; side= :right)
313- @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
314- =#
325+ hVᴴ2 = collect(Vᴴ2)
326+ hNᴴ2 = collect(Nᴴ2)
327+ @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ diagm(ones(T, size(Vᴴ2, 2 ))) atol = m* n* MatrixAlgebraKit. defaulttol(T)
315328 else
316329 @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind= kind,
317330 trunc= (; atol= atol))
0 commit comments