Skip to content

Commit f79576b

Browse files
committed
Add tests for image and null space for GPU
1 parent ba9867b commit f79576b

File tree

3 files changed

+673
-0
lines changed

3 files changed

+673
-0
lines changed

test/amd/orthnull.jl

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, I, mul!, norm
6+
using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
7+
using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
8+
initialize_output, AbstractAlgorithm
9+
using AMDGPU
10+
11+
# Used to test non-AbstractMatrix codepaths.
12+
struct LinearMap{P<:AbstractMatrix}
13+
parent::P
14+
end
15+
Base.parent(A::LinearMap) = getfield(A, :parent)
16+
function Base.copy!(dest::LinearMap, src::LinearMap)
17+
copy!(parent(dest), parent(src))
18+
return dest
19+
end
20+
function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap)
21+
mul!(parent(C), parent(A), parent(B))
22+
return C
23+
end
24+
25+
function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap)
26+
return LinearMap(copy_input(qr_compact, parent(A)))
27+
end
28+
function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap)
29+
return LinearMap(copy_input(lq_compact, parent(A)))
30+
end
31+
function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap)
32+
return LinearMap.(initialize_output(left_orth!, parent(A)))
33+
end
34+
function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap)
35+
return LinearMap.(initialize_output(right_orth!, parent(A)))
36+
end
37+
function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
38+
return check_input(left_orth!, parent(A), parent.(VC), alg)
39+
end
40+
function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
41+
return check_input(right_orth!, parent(A), parent.(VC), alg)
42+
end
43+
function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A}
44+
return default_svd_algorithm(A; kwargs...)
45+
end
46+
function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::LinearMap,
47+
alg::GPU_SVDAlgorithm)
48+
return LinearMap.(initialize_output(svd_compact!, parent(A), alg))
49+
end
50+
function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::GPU_SVDAlgorithm)
51+
return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg))
52+
end
53+
54+
@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
55+
rng = StableRNG(123)
56+
m = 54
57+
for n in (37, m, 63)
58+
minmn = min(m, n)
59+
A = ROCArray(randn(rng, T, m, n))
60+
V, C = @constinferred left_orth(A)
61+
N = @constinferred left_null(A)
62+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
63+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
64+
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
65+
@test V * C A
66+
@test isisometric(V)
67+
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
68+
@test isisometric(N)
69+
hV = collect(V)
70+
hN = collect(N)
71+
@test hV * hV' + hN * hN' ≈ I
72+
73+
M = LinearMap(A)
74+
VM, CM = @constinferred left_orth(M; kind=:svd)
75+
@test parent(VM) * parent(CM) ≈ A
76+
77+
if m > n
78+
nullity = 5
79+
V, C = @constinferred left_orth(A)
80+
# doesn't work because of truncation
81+
#N = @constinferred left_null(A; trunc=(; maxnullity=nullity))
82+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
83+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
84+
#@test N isa ROCMatrix{T} && size(N) == (m, nullity)
85+
@test V * C A
86+
@test isisometric(V)
87+
#@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
88+
#@test isisometric(N)
89+
end
90+
91+
for alg_qr in ((; positive=true), (; positive=false), ROCSOLVER_HouseholderQR())
92+
V, C = @constinferred left_orth(A; alg_qr)
93+
N = @constinferred left_null(A; alg_qr)
94+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
95+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
96+
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
97+
@test V * C A
98+
@test isisometric(V)
99+
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
100+
@test isisometric(N)
101+
hV = collect(V)
102+
hN = collect(N)
103+
@test hV * hV' + hN * hN' ≈ I
104+
end
105+
106+
Ac = similar(A)
107+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C))
108+
N2 = @constinferred left_null!(copy!(Ac, A), N)
109+
@test V2 === V
110+
@test C2 === C
111+
@test N2 === N
112+
@test V2 * C2 ≈ A
113+
@test isisometric(V2)
114+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
115+
@test isisometric(N2)
116+
hV2 = collect(V2)
117+
hN2 = collect(N2)
118+
@test hV2 * hV2' + hN2 * hN2' I
119+
120+
atol = eps(real(T))
121+
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol))
122+
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=(; atol=atol))
123+
#@test V2 !== V
124+
#@test C2 !== C
125+
@test N2 !== C
126+
#@test V2 * C2 ≈ A
127+
#@test isisometric(V2)
128+
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
129+
@test isisometric(N2)
130+
#@test V2 * V2' + N2 * N2' ≈ I
131+
132+
rtol = eps(real(T))
133+
for (trunc_orth, trunc_null) in (((; rtol=rtol), (; rtol=rtol)),
134+
(TruncationKeepAbove(0, rtol), TruncationKeepBelow(0, rtol)))
135+
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth)
136+
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=trunc_null)
137+
#@test V2 !== V
138+
#@test C2 !== C
139+
@test N2 !== C
140+
#@test V2 * C2 ≈ A
141+
#@test isisometric(V2)
142+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
143+
@test isisometric(N2)
144+
#@test V2 * V2' + N2 * N2' ≈ I
145+
end
146+
147+
for kind in (:qr, :polar, :svd) # explicit kind kwarg
148+
m < n && kind == :polar && continue
149+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind)
150+
@test V2 === V
151+
@test C2 === C
152+
@test V2 * C2 A
153+
@test isisometric(V2)
154+
if kind != :polar
155+
N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind)
156+
@test N2 === N
157+
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
158+
@test isisometric(N2)
159+
hV2 = collect(V2)
160+
hN2 = collect(N2)
161+
@test hV2 * hV2' + hN2 * hN2' ≈ I
162+
end
163+
164+
# with kind and tol kwargs
165+
if kind == :svd
166+
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
167+
# trunc=(; atol=atol))
168+
N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind,
169+
trunc=(; atol=atol))
170+
#@test V2 !== V
171+
#@test C2 !== C
172+
@test N2 !== C
173+
#@test V2 * C2 ≈ A
174+
#@test V2' * V2 I
175+
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
176+
@test isisometric(N2)
177+
#@test V2 * V2' + N2 * N2' ≈ I
178+
179+
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
180+
# trunc=(; rtol=rtol))
181+
N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind,
182+
trunc=(; rtol=rtol))
183+
#@test V2 !== V
184+
#@test C2 !== C
185+
@test N2 !== C
186+
#@test V2 * C2 ≈ A
187+
#@test isisometric(V2)
188+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
189+
@test isisometric(N2)
190+
#@test V2 * V2' + N2 * N2' ≈ I
191+
else
192+
@test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind=kind,
193+
trunc=(; atol=atol))
194+
@test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind=kind,
195+
trunc=(; rtol=rtol))
196+
@test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind,
197+
trunc=(; atol=atol))
198+
@test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind,
199+
trunc=(; rtol=rtol))
200+
end
201+
end
202+
end
203+
end
204+
205+
@testset "right_orth and right_null for T = $T" for T in (Float32, Float64, ComplexF32,
206+
ComplexF64)
207+
rng = StableRNG(123)
208+
m = 54
209+
@testset for n in (37, m, 63)
210+
minmn = min(m, n)
211+
A = ROCArray(randn(rng, T, m, n))
212+
C, Vᴴ = @constinferred right_orth(A)
213+
Nᴴ = @constinferred right_null(A)
214+
@test C isa ROCMatrix{T} && size(C) == (m, minmn)
215+
@test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (minmn, n)
216+
@test Nᴴ isa ROCMatrix{T} && size(Nᴴ) == (n - minmn, n)
217+
@test C * Vᴴ A
218+
@test isisometric(Vᴴ; side=:right)
219+
@test LinearAlgebra.norm(A * adjoint(Nᴴ)) 0 atol = MatrixAlgebraKit.defaulttol(T)
220+
@test isisometric(Nᴴ; side=:right)
221+
hVᴴ = collect(Vᴴ)
222+
hNᴴ = collect(Nᴴ)
223+
@test hVᴴ' * hVᴴ + hNᴴ' * hNᴴ I
224+
225+
M = LinearMap(A)
226+
CM, VMᴴ = @constinferred right_orth(M; kind=:svd)
227+
@test parent(CM) * parent(VMᴴ) A
228+
229+
Ac = similar(A)
230+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ))
231+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ)
232+
@test C2 === C
233+
@test Vᴴ2 === Vᴴ
234+
@test Nᴴ2 === Nᴴ
235+
@test C2 * Vᴴ2 A
236+
@test isisometric(Vᴴ2; side=:right)
237+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
238+
@test isisometric(Nᴴ; side=:right)
239+
hVᴴ2 = collect(Vᴴ2)
240+
hNᴴ2 = collect(Nᴴ2)
241+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
242+
243+
# TODO truncate currently broken due to searchsortedlast
244+
atol = eps(real(T))
245+
rtol = eps(real(T))
246+
#=C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol))
247+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; atol=atol))
248+
@test C2 !== C
249+
@test Vᴴ2 !== Vᴴ
250+
@test Nᴴ2 !== Nᴴ
251+
@test C2 * Vᴴ2 ≈ A
252+
@test isisometric(Vᴴ2; side=:right)
253+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
254+
@test isisometric(Nᴴ; side=:right)
255+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
256+
257+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; rtol=rtol))
258+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; rtol=rtol))
259+
@test C2 !== C
260+
@test Vᴴ2 !== Vᴴ
261+
@test Nᴴ2 !== Nᴴ
262+
@test C2 * Vᴴ2 ≈ A
263+
@test isisometric(Vᴴ2; side=:right)
264+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
265+
@test isisometric(Nᴴ2; side=:right)
266+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
267+
=#
268+
269+
@testset "kind = $kind" for kind in (:lq, :polar, :svd)
270+
n < m && kind == :polar && continue
271+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind)
272+
@test C2 === C
273+
@test Vᴴ2 === Vᴴ
274+
A2 = C2 * Vᴴ2
275+
@test A2 A
276+
@test isisometric(Vᴴ2; side=:right)
277+
if kind != :polar
278+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind)
279+
@test Nᴴ2 === Nᴴ
280+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
281+
@test isisometric(Nᴴ2; side=:right)
282+
hVᴴ2 = collect(Vᴴ2)
283+
hNᴴ2 = collect(Nᴴ2)
284+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
285+
end
286+
287+
if kind == :svd
288+
# doesn't work yet because of searchsortedfirst
289+
#= C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
290+
trunc=(; atol=atol))
291+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind,
292+
trunc=(; atol=atol))
293+
@test C2 !== C
294+
@test Vᴴ2 !== Vᴴ
295+
@test Nᴴ2 !== Nᴴ
296+
@test C2 * Vᴴ2 ≈ A
297+
@test isisometric(Vᴴ2; side=:right)
298+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
299+
@test isisometric(Nᴴ2; side=:right)
300+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
301+
302+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
303+
trunc=(; rtol=rtol))
304+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind,
305+
trunc=(; rtol=rtol))
306+
@test C2 !== C
307+
@test Vᴴ2 !== Vᴴ
308+
@test Nᴴ2 !== Nᴴ
309+
@test C2 * Vᴴ2 ≈ A
310+
@test isisometric(Vᴴ2; side=:right)
311+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
312+
@test isisometric(Nᴴ2; side=:right)
313+
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
314+
=#
315+
else
316+
@test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
317+
trunc=(; atol=atol))
318+
@test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
319+
trunc=(; rtol=rtol))
320+
@test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind,
321+
trunc=(; atol=atol))
322+
@test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind,
323+
trunc=(; rtol=rtol))
324+
end
325+
end
326+
end
327+
end

0 commit comments

Comments
 (0)