Skip to content

Commit 34839bd

Browse files
kshyattKatharine Hyatt
authored andcommitted
Add tests for image and null space for GPU
1 parent ea14b0c commit 34839bd

File tree

4 files changed

+770
-1
lines changed

4 files changed

+770
-1
lines changed

src/implementations/truncation.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ function truncate(::typeof(left_null!), (U, S), strategy::TruncationStrategy)
1717
# TODO: avoid allocation?
1818
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
1919
ind = findtruncated(extended_S, strategy)
20-
return U[:, ind], ind
20+
trunc_cols = collect(1:size(U, 2))[ind]
21+
Utrunc = similar(U, (size(U, 1), length(trunc_cols)))
22+
Utrunc .= U[:, trunc_cols]
23+
return Utrunc, ind
2124
end
2225
function truncate(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy)
2326
# TODO: avoid allocation?

test/amd/orthnull.jl

Lines changed: 384 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,384 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, I, mul!, diagm, norm
6+
using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
7+
initialize_output, AbstractAlgorithm
8+
using AMDGPU
9+
10+
# Used to test non-AbstractMatrix codepaths.
11+
struct LinearMap{P <: AbstractMatrix}
12+
parent::P
13+
end
14+
Base.parent(A::LinearMap) = getfield(A, :parent)
15+
function Base.copy!(dest::LinearMap, src::LinearMap)
16+
copy!(parent(dest), parent(src))
17+
return dest
18+
end
19+
function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap)
20+
mul!(parent(C), parent(A), parent(B))
21+
return C
22+
end
23+
24+
function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap)
25+
return LinearMap(copy_input(qr_compact, parent(A)))
26+
end
27+
function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap)
28+
return LinearMap(copy_input(lq_compact, parent(A)))
29+
end
30+
function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap)
31+
return LinearMap.(initialize_output(left_orth!, parent(A)))
32+
end
33+
function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap)
34+
return LinearMap.(initialize_output(right_orth!, parent(A)))
35+
end
36+
function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
37+
return check_input(left_orth!, parent(A), parent.(VC), alg)
38+
end
39+
function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
40+
return check_input(right_orth!, parent(A), parent.(VC), alg)
41+
end
42+
function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A}
43+
return default_svd_algorithm(A; kwargs...)
44+
end
45+
function MatrixAlgebraKit.initialize_output(
46+
::typeof(svd_compact!), A::LinearMap,
47+
alg::GPU_SVDAlgorithm
48+
)
49+
return LinearMap.(initialize_output(svd_compact!, parent(A), alg))
50+
end
51+
function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::GPU_SVDAlgorithm)
52+
return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg))
53+
end
54+
55+
@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
56+
rng = StableRNG(123)
57+
m = 54
58+
@testset for n in (37, m, 63)
59+
minmn = min(m, n)
60+
A = ROCArray(randn(rng, T, m, n))
61+
V, C = @constinferred left_orth(A)
62+
N = @constinferred left_null(A)
63+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
64+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
65+
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
66+
@test V * C A
67+
@test isisometric(V)
68+
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
69+
@test isisometric(N)
70+
hV = collect(V)
71+
hN = collect(N)
72+
@test hV * hV' + hN * hN' ≈ I
73+
74+
M = LinearMap(A)
75+
VM, CM = @constinferred left_orth(M; kind = :svd)
76+
@test parent(VM) * parent(CM) ≈ A
77+
78+
if m > n
79+
nullity = 5
80+
V, C = @constinferred left_orth(A)
81+
AMDGPU.@allowscalar begin
82+
N = @constinferred left_null(A; trunc = (; maxnullity = nullity))
83+
end
84+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
85+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
86+
@test N isa ROCMatrix{T} && size(N) == (m, nullity)
87+
@test V * C ≈ A
88+
@test isisometric(V)
89+
@test LinearAlgebra.norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
90+
@test isisometric(N)
91+
end
92+
93+
for alg_qr in ((; positive = true), (; positive = false), ROCSOLVER_HouseholderQR())
94+
V, C = @constinferred left_orth(A; alg_qr)
95+
N = @constinferred left_null(A; alg_qr)
96+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
97+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
98+
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
99+
@test V * C A
100+
@test isisometric(V)
101+
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
102+
@test isisometric(N)
103+
hV = collect(V)
104+
hN = collect(N)
105+
@test hV * hV' + hN * hN' ≈ I
106+
end
107+
108+
Ac = similar(A)
109+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C))
110+
N2 = @constinferred left_null!(copy!(Ac, A), N)
111+
@test V2 === V
112+
@test C2 === C
113+
@test N2 === N
114+
@test V2 * C2 ≈ A
115+
@test isisometric(V2)
116+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
117+
@test isisometric(N2)
118+
hV2 = collect(V2)
119+
hN2 = collect(N2)
120+
@test hV2 * hV2' + hN2 * hN2' I
121+
122+
atol = eps(real(T))
123+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol))
124+
AMDGPU.@allowscalar begin
125+
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol))
126+
end
127+
@test V2 !== V
128+
@test C2 !== C
129+
@test N2 !== C
130+
@test V2 * C2 A
131+
@test isisometric(V2)
132+
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
133+
@test isisometric(N2)
134+
hV2 = collect(V2)
135+
hN2 = collect(N2)
136+
@test hV2 * hV2' + hN2 * hN2' ≈ I
137+
138+
rtol = eps(real(T))
139+
for (trunc_orth, trunc_null) in (
140+
((; rtol = rtol), (; rtol = rtol)),
141+
(trunctol(; rtol), trunctol(; rtol, keep_below = true)),
142+
)
143+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth)
144+
AMDGPU.@allowscalar begin
145+
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null)
146+
end
147+
@test V2 !== V
148+
@test C2 !== C
149+
@test N2 !== C
150+
@test V2 * C2 ≈ A
151+
@test isisometric(V2)
152+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
153+
@test isisometric(N2)
154+
hV2 = collect(V2)
155+
hN2 = collect(N2)
156+
@test hV2 * hV2' + hN2 * hN2' I
157+
end
158+
159+
@testset for kind in (:qr, :polar, :svd) # explicit kind kwarg
160+
m < n && kind == :polar && continue
161+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind = kind)
162+
@test V2 === V
163+
@test C2 === C
164+
@test V2 * C2 A
165+
@test isisometric(V2)
166+
if kind != :polar
167+
N2 = @constinferred left_null!(copy!(Ac, A), N; kind = kind)
168+
@test N2 === N
169+
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
170+
@test isisometric(N2)
171+
hV2 = collect(V2)
172+
hN2 = collect(N2)
173+
@test hV2 * hV2' + hN2 * hN2' ≈ I
174+
end
175+
176+
# with kind and tol kwargs
177+
if kind == :svd
178+
V2, C2 = @constinferred left_orth!(
179+
copy!(Ac, A), (V, C); kind = kind,
180+
trunc = (; atol = atol)
181+
)
182+
AMDGPU.@allowscalar begin
183+
N2 = @constinferred left_null!(
184+
copy!(Ac, A), N; kind = kind,
185+
trunc = (; atol = atol)
186+
)
187+
end
188+
@test V2 !== V
189+
@test C2 !== C
190+
@test N2 !== C
191+
@test V2 * C2 ≈ A
192+
@test V2' * V2 I
193+
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
194+
@test isisometric(N2)
195+
hV2 = collect(V2)
196+
hN2 = collect(N2)
197+
@test hV2 * hV2' + hN2 * hN2' ≈ I
198+
199+
V2, C2 = @constinferred left_orth!(
200+
copy!(Ac, A), (V, C); kind = kind,
201+
trunc = (; rtol = rtol)
202+
)
203+
AMDGPU.@allowscalar begin
204+
N2 = @constinferred left_null!(
205+
copy!(Ac, A), N; kind = kind,
206+
trunc = (; rtol = rtol)
207+
)
208+
end
209+
@test V2 !== V
210+
@test C2 !== C
211+
@test N2 !== C
212+
@test V2 * C2 ≈ A
213+
@test isisometric(V2)
214+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
215+
@test isisometric(N2)
216+
hV2 = collect(V2)
217+
hN2 = collect(N2)
218+
@test hV2 * hV2' + hN2 * hN2' I
219+
else
220+
@test_throws ArgumentError left_orth!(
221+
copy!(Ac, A), (V, C); kind = kind,
222+
trunc = (; atol = atol)
223+
)
224+
@test_throws ArgumentError left_orth!(
225+
copy!(Ac, A), (V, C); kind = kind,
226+
trunc = (; rtol = rtol)
227+
)
228+
@test_throws ArgumentError left_null!(
229+
copy!(Ac, A), N; kind = kind,
230+
trunc = (; atol = atol)
231+
)
232+
@test_throws ArgumentError left_null!(
233+
copy!(Ac, A), N; kind = kind,
234+
trunc = (; rtol = rtol)
235+
)
236+
end
237+
end
238+
end
239+
end
240+
241+
@testset "right_orth and right_null for T = $T" for T in (
242+
Float32, Float64, ComplexF32,
243+
ComplexF64,
244+
)
245+
rng = StableRNG(123)
246+
m = 54
247+
@testset for n in (37, m, 63)
248+
minmn = min(m, n)
249+
A = ROCArray(randn(rng, T, m, n))
250+
C, Vᴴ = @constinferred right_orth(A)
251+
Nᴴ = @constinferred right_null(A)
252+
@test C isa ROCMatrix{T} && size(C) == (m, minmn)
253+
@test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (minmn, n)
254+
@test Nᴴ isa ROCMatrix{T} && size(Nᴴ) == (n - minmn, n)
255+
@test C * Vᴴ A
256+
@test isisometric(Vᴴ; side = :right)
257+
@test LinearAlgebra.norm(A * adjoint(Nᴴ)) 0 atol = MatrixAlgebraKit.defaulttol(T)
258+
@test isisometric(Nᴴ; side = :right)
259+
hVᴴ = collect(Vᴴ)
260+
hNᴴ = collect(Nᴴ)
261+
@test hVᴴ' * hVᴴ + hNᴴ' * hNᴴ I
262+
263+
M = LinearMap(A)
264+
CM, VMᴴ = @constinferred right_orth(M; kind = :svd)
265+
@test parent(CM) * parent(VMᴴ) A
266+
267+
Ac = similar(A)
268+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ))
269+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ)
270+
@test C2 === C
271+
@test Vᴴ2 === Vᴴ
272+
@test Nᴴ2 === Nᴴ
273+
@test C2 * Vᴴ2 A
274+
@test isisometric(Vᴴ2; side = :right)
275+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
276+
@test isisometric(Nᴴ; side = :right)
277+
hVᴴ2 = collect(Vᴴ2)
278+
hNᴴ2 = collect(Nᴴ2)
279+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
280+
281+
atol = eps(real(T))
282+
rtol = eps(real(T))
283+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol = atol))
284+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol = atol))
285+
@test C2 !== C
286+
@test Vᴴ2 !== Vᴴ
287+
@test Nᴴ2 !== Nᴴ
288+
@test C2 * Vᴴ2 A
289+
@test isisometric(Vᴴ2; side = :right)
290+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
291+
@test isisometric(Nᴴ; side = :right)
292+
hVᴴ2 = collect(Vᴴ2)
293+
hNᴴ2 = collect(Nᴴ2)
294+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
295+
296+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol = rtol))
297+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol = rtol))
298+
@test C2 !== C
299+
@test Vᴴ2 !== Vᴴ
300+
@test Nᴴ2 !== Nᴴ
301+
@test C2 * Vᴴ2 A
302+
@test isisometric(Vᴴ2; side = :right)
303+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
304+
@test isisometric(Nᴴ2; side = :right)
305+
hVᴴ2 = collect(Vᴴ2)
306+
hNᴴ2 = collect(Nᴴ2)
307+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
308+
309+
@testset "kind = $kind" for kind in (:lq, :polar, :svd)
310+
n < m && kind == :polar && continue
311+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind = kind)
312+
@test C2 === C
313+
@test Vᴴ2 === Vᴴ
314+
@test C2 * Vᴴ2 A
315+
@test isisometric(Vᴴ2; side = :right)
316+
if kind != :polar
317+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind = kind)
318+
@test Nᴴ2 === Nᴴ
319+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
320+
@test isisometric(Nᴴ2; side = :right)
321+
hVᴴ2 = collect(Vᴴ2)
322+
hNᴴ2 = collect(Nᴴ2)
323+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
324+
end
325+
326+
if kind == :svd
327+
C2, Vᴴ2 = @constinferred right_orth!(
328+
copy!(Ac, A), (C, Vᴴ); kind = kind,
329+
trunc = (; atol = atol)
330+
)
331+
Nᴴ2 = @constinferred right_null!(
332+
copy!(Ac, A), Nᴴ; kind = kind,
333+
trunc = (; atol = atol)
334+
)
335+
@test C2 !== C
336+
@test Vᴴ2 !== Vᴴ
337+
@test Nᴴ2 !== Nᴴ
338+
@test C2 * Vᴴ2 A
339+
@test isisometric(Vᴴ2; side = :right)
340+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
341+
@test isisometric(Nᴴ2; side = :right)
342+
hVᴴ2 = collect(Vᴴ2)
343+
hNᴴ2 = collect(Nᴴ2)
344+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
345+
346+
C2, Vᴴ2 = @constinferred right_orth!(
347+
copy!(Ac, A), (C, Vᴴ); kind = kind,
348+
trunc = (; rtol = rtol)
349+
)
350+
Nᴴ2 = @constinferred right_null!(
351+
copy!(Ac, A), Nᴴ; kind = kind,
352+
trunc = (; rtol = rtol)
353+
)
354+
@test C2 !== C
355+
@test Vᴴ2 !== Vᴴ
356+
@test Nᴴ2 !== Nᴴ
357+
@test C2 * Vᴴ2 A
358+
@test isisometric(Vᴴ2; side = :right)
359+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
360+
@test isisometric(Nᴴ2; side = :right)
361+
hVᴴ2 = collect(Vᴴ2)
362+
hNᴴ2 = collect(Nᴴ2)
363+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 diagm(ones(T, size(Vᴴ2, 2))) atol = m * n * MatrixAlgebraKit.defaulttol(T)
364+
else
365+
@test_throws ArgumentError right_orth!(
366+
copy!(Ac, A), (C, Vᴴ); kind = kind,
367+
trunc = (; atol = atol)
368+
)
369+
@test_throws ArgumentError right_orth!(
370+
copy!(Ac, A), (C, Vᴴ); kind = kind,
371+
trunc = (; rtol = rtol)
372+
)
373+
@test_throws ArgumentError right_null!(
374+
copy!(Ac, A), Nᴴ; kind = kind,
375+
trunc = (; atol = atol)
376+
)
377+
@test_throws ArgumentError right_null!(
378+
copy!(Ac, A), Nᴴ; kind = kind,
379+
trunc = (; rtol = rtol)
380+
)
381+
end
382+
end
383+
end
384+
end

0 commit comments

Comments
 (0)