Skip to content

Commit 29559ac

Browse files
committed
Add tests for image and null space for GPU
1 parent ba9867b commit 29559ac

File tree

3 files changed

+735
-0
lines changed

3 files changed

+735
-0
lines changed

test/amd/orthnull.jl

Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, I, mul!, norm
6+
using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
7+
using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
8+
initialize_output, AbstractAlgorithm
9+
using AMDGPU
10+
11+
# Used to test non-AbstractMatrix codepaths.
12+
struct LinearMap{P <: AbstractMatrix}
13+
parent::P
14+
end
15+
Base.parent(A::LinearMap) = getfield(A, :parent)
16+
function Base.copy!(dest::LinearMap, src::LinearMap)
17+
copy!(parent(dest), parent(src))
18+
return dest
19+
end
20+
function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap)
21+
mul!(parent(C), parent(A), parent(B))
22+
return C
23+
end
24+
25+
function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap)
26+
return LinearMap(copy_input(qr_compact, parent(A)))
27+
end
28+
function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap)
29+
return LinearMap(copy_input(lq_compact, parent(A)))
30+
end
31+
function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap)
32+
return LinearMap.(initialize_output(left_orth!, parent(A)))
33+
end
34+
function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap)
35+
return LinearMap.(initialize_output(right_orth!, parent(A)))
36+
end
37+
function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
38+
return check_input(left_orth!, parent(A), parent.(VC), alg)
39+
end
40+
function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
41+
return check_input(right_orth!, parent(A), parent.(VC), alg)
42+
end
43+
function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A}
44+
return default_svd_algorithm(A; kwargs...)
45+
end
46+
function MatrixAlgebraKit.initialize_output(
47+
::typeof(svd_compact!), A::LinearMap,
48+
alg::GPU_SVDAlgorithm
49+
)
50+
return LinearMap.(initialize_output(svd_compact!, parent(A), alg))
51+
end
52+
function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::GPU_SVDAlgorithm)
53+
return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg))
54+
end
55+
56+
@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
57+
rng = StableRNG(123)
58+
m = 54
59+
for n in (37, m, 63)
60+
minmn = min(m, n)
61+
A = ROCArray(randn(rng, T, m, n))
62+
V, C = @constinferred left_orth(A)
63+
N = @constinferred left_null(A)
64+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
65+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
66+
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
67+
@test V * C A
68+
@test isisometric(V)
69+
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
70+
@test isisometric(N)
71+
hV = collect(V)
72+
hN = collect(N)
73+
@test hV * hV' + hN * hN' ≈ I
74+
75+
M = LinearMap(A)
76+
VM, CM = @constinferred left_orth(M; kind = :svd)
77+
@test parent(VM) * parent(CM) ≈ A
78+
79+
if m > n
80+
nullity = 5
81+
V, C = @constinferred left_orth(A)
82+
# doesn't work because of truncation
83+
#N = @constinferred left_null(A; trunc=(; maxnullity=nullity))
84+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
85+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
86+
#@test N isa ROCMatrix{T} && size(N) == (m, nullity)
87+
@test V * C A
88+
@test isisometric(V)
89+
#@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
90+
#@test isisometric(N)
91+
end
92+
93+
for alg_qr in ((; positive = true), (; positive = false), ROCSOLVER_HouseholderQR())
94+
V, C = @constinferred left_orth(A; alg_qr)
95+
N = @constinferred left_null(A; alg_qr)
96+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
97+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
98+
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
99+
@test V * C A
100+
@test isisometric(V)
101+
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
102+
@test isisometric(N)
103+
hV = collect(V)
104+
hN = collect(N)
105+
@test hV * hV' + hN * hN' ≈ I
106+
end
107+
108+
Ac = similar(A)
109+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C))
110+
N2 = @constinferred left_null!(copy!(Ac, A), N)
111+
@test V2 === V
112+
@test C2 === C
113+
@test N2 === N
114+
@test V2 * C2 ≈ A
115+
@test isisometric(V2)
116+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
117+
@test isisometric(N2)
118+
hV2 = collect(V2)
119+
hN2 = collect(N2)
120+
@test hV2 * hV2' + hN2 * hN2' I
121+
122+
atol = eps(real(T))
123+
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol))
124+
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol))
125+
#@test V2 !== V
126+
#@test C2 !== C
127+
@test N2 !== C
128+
#@test V2 * C2 ≈ A
129+
#@test isisometric(V2)
130+
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
131+
@test isisometric(N2)
132+
#@test V2 * V2' + N2 * N2' ≈ I
133+
134+
rtol = eps(real(T))
135+
for (trunc_orth, trunc_null) in (
136+
((; rtol = rtol), (; rtol = rtol)),
137+
(TruncationKeepAbove(0, rtol), TruncationKeepBelow(0, rtol)),
138+
)
139+
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth)
140+
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null)
141+
#@test V2 !== V
142+
#@test C2 !== C
143+
@test N2 !== C
144+
#@test V2 * C2 ≈ A
145+
#@test isisometric(V2)
146+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
147+
@test isisometric(N2)
148+
#@test V2 * V2' + N2 * N2' ≈ I
149+
end
150+
151+
for kind in (:qr, :polar, :svd) # explicit kind kwarg
152+
m < n && kind == :polar && continue
153+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind = kind)
154+
@test V2 === V
155+
@test C2 === C
156+
@test V2 * C2 A
157+
@test isisometric(V2)
158+
if kind != :polar
159+
N2 = @constinferred left_null!(copy!(Ac, A), N; kind = kind)
160+
@test N2 === N
161+
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
162+
@test isisometric(N2)
163+
hV2 = collect(V2)
164+
hN2 = collect(N2)
165+
@test hV2 * hV2' + hN2 * hN2' ≈ I
166+
end
167+
168+
# with kind and tol kwargs
169+
if kind == :svd
170+
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
171+
# trunc=(; atol=atol))
172+
N2 = @constinferred left_null!(
173+
copy!(Ac, A), N; kind = kind,
174+
trunc = (; atol = atol)
175+
)
176+
#@test V2 !== V
177+
#@test C2 !== C
178+
@test N2 !== C
179+
#@test V2 * C2 ≈ A
180+
#@test V2' * V2 I
181+
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
182+
@test isisometric(N2)
183+
#@test V2 * V2' + N2 * N2' ≈ I
184+
185+
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
186+
# trunc=(; rtol=rtol))
187+
N2 = @constinferred left_null!(
188+
copy!(Ac, A), N; kind = kind,
189+
trunc = (; rtol = rtol)
190+
)
191+
#@test V2 !== V
192+
#@test C2 !== C
193+
@test N2 !== C
194+
#@test V2 * C2 ≈ A
195+
#@test isisometric(V2)
196+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
197+
@test isisometric(N2)
198+
#@test V2 * V2' + N2 * N2' ≈ I
199+
else
200+
@test_throws ArgumentError left_orth!(
201+
copy!(Ac, A), (V, C); kind = kind,
202+
trunc = (; atol = atol)
203+
)
204+
@test_throws ArgumentError left_orth!(
205+
copy!(Ac, A), (V, C); kind = kind,
206+
trunc = (; rtol = rtol)
207+
)
208+
@test_throws ArgumentError left_null!(
209+
copy!(Ac, A), N; kind = kind,
210+
trunc = (; atol = atol)
211+
)
212+
@test_throws ArgumentError left_null!(
213+
copy!(Ac, A), N; kind = kind,
214+
trunc = (; rtol = rtol)
215+
)
216+
end
217+
end
218+
end
219+
end
220+
221+
@testset "right_orth and right_null for T = $T" for T in (
222+
Float32, Float64, ComplexF32,
223+
ComplexF64,
224+
)
225+
rng = StableRNG(123)
226+
m = 54
227+
@testset for n in (37, m, 63)
228+
minmn = min(m, n)
229+
A = ROCArray(randn(rng, T, m, n))
230+
C, Vᴴ = @constinferred right_orth(A)
231+
Nᴴ = @constinferred right_null(A)
232+
@test C isa ROCMatrix{T} && size(C) == (m, minmn)
233+
@test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (minmn, n)
234+
@test Nᴴ isa ROCMatrix{T} && size(Nᴴ) == (n - minmn, n)
235+
@test C * Vᴴ A
236+
@test isisometric(Vᴴ; side = :right)
237+
@test LinearAlgebra.norm(A * adjoint(Nᴴ)) 0 atol = MatrixAlgebraKit.defaulttol(T)
238+
@test isisometric(Nᴴ; side = :right)
239+
hVᴴ = collect(Vᴴ)
240+
hNᴴ = collect(Nᴴ)
241+
@test hVᴴ' * hVᴴ + hNᴴ' * hNᴴ I
242+
243+
M = LinearMap(A)
244+
CM, VMᴴ = @constinferred right_orth(M; kind = :svd)
245+
@test parent(CM) * parent(VMᴴ) A
246+
247+
Ac = similar(A)
248+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ))
249+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ)
250+
@test C2 === C
251+
@test Vᴴ2 === Vᴴ
252+
@test Nᴴ2 === Nᴴ
253+
@test C2 * Vᴴ2 A
254+
@test isisometric(Vᴴ2; side = :right)
255+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
256+
@test isisometric(Nᴴ; side = :right)
257+
hVᴴ2 = collect(Vᴴ2)
258+
hNᴴ2 = collect(Nᴴ2)
259+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
260+
261+
# TODO truncate currently broken due to searchsortedlast
262+
atol = eps(real(T))
263+
rtol = eps(real(T))
264+
#=C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol))
265+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; atol=atol))
266+
@test C2 !== C
267+
@test Vᴴ2 !== Vᴴ
268+
@test Nᴴ2 !== Nᴴ
269+
@test C2 * Vᴴ2 ≈ A
270+
@test isisometric(Vᴴ2; side=:right)
271+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
272+
@test isisometric(Nᴴ; side=:right)
273+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
274+
275+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; rtol=rtol))
276+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; rtol=rtol))
277+
@test C2 !== C
278+
@test Vᴴ2 !== Vᴴ
279+
@test Nᴴ2 !== Nᴴ
280+
@test C2 * Vᴴ2 ≈ A
281+
@test isisometric(Vᴴ2; side=:right)
282+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
283+
@test isisometric(Nᴴ2; side=:right)
284+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
285+
=#
286+
287+
@testset "kind = $kind" for kind in (:lq, :polar, :svd)
288+
n < m && kind == :polar && continue
289+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind = kind)
290+
@test C2 === C
291+
@test Vᴴ2 === Vᴴ
292+
A2 = C2 * Vᴴ2
293+
@test A2 A
294+
@test isisometric(Vᴴ2; side = :right)
295+
if kind != :polar
296+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind = kind)
297+
@test Nᴴ2 === Nᴴ
298+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
299+
@test isisometric(Nᴴ2; side = :right)
300+
hVᴴ2 = collect(Vᴴ2)
301+
hNᴴ2 = collect(Nᴴ2)
302+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
303+
end
304+
305+
if kind == :svd
306+
# doesn't work yet because of searchsortedfirst
307+
#= C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
308+
trunc=(; atol=atol))
309+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind,
310+
trunc=(; atol=atol))
311+
@test C2 !== C
312+
@test Vᴴ2 !== Vᴴ
313+
@test Nᴴ2 !== Nᴴ
314+
@test C2 * Vᴴ2 ≈ A
315+
@test isisometric(Vᴴ2; side=:right)
316+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
317+
@test isisometric(Nᴴ2; side=:right)
318+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
319+
320+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
321+
trunc=(; rtol=rtol))
322+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind,
323+
trunc=(; rtol=rtol))
324+
@test C2 !== C
325+
@test Vᴴ2 !== Vᴴ
326+
@test Nᴴ2 !== Nᴴ
327+
@test C2 * Vᴴ2 ≈ A
328+
@test isisometric(Vᴴ2; side=:right)
329+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
330+
@test isisometric(Nᴴ2; side=:right)
331+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
332+
=#
333+
else
334+
@test_throws ArgumentError right_orth!(
335+
copy!(Ac, A), (C, Vᴴ); kind = kind,
336+
trunc = (; atol = atol)
337+
)
338+
@test_throws ArgumentError right_orth!(
339+
copy!(Ac, A), (C, Vᴴ); kind = kind,
340+
trunc = (; rtol = rtol)
341+
)
342+
@test_throws ArgumentError right_null!(
343+
copy!(Ac, A), Nᴴ; kind = kind,
344+
trunc = (; atol = atol)
345+
)
346+
@test_throws ArgumentError right_null!(
347+
copy!(Ac, A), Nᴴ; kind = kind,
348+
trunc = (; rtol = rtol)
349+
)
350+
end
351+
end
352+
end
353+
end

0 commit comments

Comments
 (0)