Skip to content

Commit 4fbc3bf

Browse files
authored
Add tests for image and null space for GPU (#82)
* Add tests for image and null space for GPU
1 parent 50eb537 commit 4fbc3bf

File tree

6 files changed

+745
-48
lines changed

6 files changed

+745
-48
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,14 @@ function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
161161
return A, B
162162
end
163163

164+
function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tuple{TU, TS}, strategy::MatrixAlgebraKit.TruncationStrategy) where {TU <: ROCArray, TS}
165+
# TODO: avoid allocation?
166+
U, S = US
167+
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
168+
ind = MatrixAlgebraKit.findtruncated(extended_S, strategy)
169+
trunc_cols = collect(1:size(U, 2))[ind]
170+
Utrunc = U[:, trunc_cols]
171+
return Utrunc, ind
172+
end
173+
164174
end

test/amd/orthnull.jl

Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, I, mul!, diagm, norm
6+
using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
7+
initialize_output, AbstractAlgorithm
8+
using AMDGPU
9+
10+
# testing non-AbstractArray codepaths:
11+
include(joinpath("..", "linearmap.jl"))
12+
13+
@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
14+
rng = StableRNG(123)
15+
m = 54
16+
@testset for n in (37, m, 63)
17+
minmn = min(m, n)
18+
A = ROCArray(randn(rng, T, m, n))
19+
V, C = @constinferred left_orth(A)
20+
N = @constinferred left_null(A)
21+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
22+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
23+
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
24+
@test V * C A
25+
@test isisometric(V)
26+
@test norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
27+
@test isisometric(N)
28+
hV = collect(V)
29+
hN = collect(N)
30+
@test hV * hV' + hN * hN' I
31+
32+
M = LinearMap(A)
33+
VM, CM = @constinferred left_orth(M; kind = :svd)
34+
@test parent(VM) * parent(CM) A
35+
36+
if m > n
37+
nullity = 5
38+
V, C = @constinferred left_orth(A)
39+
AMDGPU.@allowscalar begin
40+
N = @constinferred left_null(A; trunc = (; maxnullity = nullity))
41+
end
42+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
43+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
44+
@test N isa ROCMatrix{T} && size(N) == (m, nullity)
45+
@test V * C A
46+
@test isisometric(V)
47+
@test LinearAlgebra.norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
48+
@test isisometric(N)
49+
end
50+
51+
for alg_qr in ((; positive = true), (; positive = false), ROCSOLVER_HouseholderQR())
52+
V, C = @constinferred left_orth(A; alg_qr)
53+
N = @constinferred left_null(A; alg_qr)
54+
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
55+
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
56+
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
57+
@test V * C A
58+
@test isisometric(V)
59+
@test norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
60+
@test isisometric(N)
61+
hV = collect(V)
62+
hN = collect(N)
63+
@test hV * hV' + hN * hN' I
64+
end
65+
66+
Ac = similar(A)
67+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C))
68+
N2 = @constinferred left_null!(copy!(Ac, A), N)
69+
@test V2 === V
70+
@test C2 === C
71+
@test N2 === N
72+
@test V2 * C2 A
73+
@test isisometric(V2)
74+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
75+
@test isisometric(N2)
76+
hV2 = collect(V2)
77+
hN2 = collect(N2)
78+
@test hV2 * hV2' + hN2 * hN2' I
79+
80+
atol = eps(real(T))
81+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol))
82+
AMDGPU.@allowscalar begin
83+
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol))
84+
end
85+
@test V2 !== V
86+
@test C2 !== C
87+
@test N2 !== C
88+
@test V2 * C2 A
89+
@test isisometric(V2)
90+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
91+
@test isisometric(N2)
92+
hV2 = collect(V2)
93+
hN2 = collect(N2)
94+
@test hV2 * hV2' + hN2 * hN2' I
95+
96+
rtol = eps(real(T))
97+
for (trunc_orth, trunc_null) in (
98+
((; rtol = rtol), (; rtol = rtol)),
99+
(trunctol(; rtol), trunctol(; rtol, keep_below = true)),
100+
)
101+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth)
102+
AMDGPU.@allowscalar begin
103+
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null)
104+
end
105+
@test V2 !== V
106+
@test C2 !== C
107+
@test N2 !== C
108+
@test V2 * C2 A
109+
@test isisometric(V2)
110+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
111+
@test isisometric(N2)
112+
hV2 = collect(V2)
113+
hN2 = collect(N2)
114+
@test hV2 * hV2' + hN2 * hN2' I
115+
end
116+
117+
@testset for kind in (:qr, :polar, :svd) # explicit kind kwarg
118+
m < n && kind == :polar && continue
119+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind = kind)
120+
@test V2 === V
121+
@test C2 === C
122+
@test V2 * C2 A
123+
@test isisometric(V2)
124+
if kind != :polar
125+
N2 = @constinferred left_null!(copy!(Ac, A), N; kind = kind)
126+
@test N2 === N
127+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
128+
@test isisometric(N2)
129+
hV2 = collect(V2)
130+
hN2 = collect(N2)
131+
@test hV2 * hV2' + hN2 * hN2' I
132+
end
133+
134+
# with kind and tol kwargs
135+
if kind == :svd
136+
V2, C2 = @constinferred left_orth!(
137+
copy!(Ac, A), (V, C); kind = kind,
138+
trunc = (; atol = atol)
139+
)
140+
AMDGPU.@allowscalar begin
141+
N2 = @constinferred left_null!(
142+
copy!(Ac, A), N; kind = kind,
143+
trunc = (; atol = atol)
144+
)
145+
end
146+
@test V2 !== V
147+
@test C2 !== C
148+
@test N2 !== C
149+
@test V2 * C2 A
150+
@test V2' * V2 I
151+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
152+
@test isisometric(N2)
153+
hV2 = collect(V2)
154+
hN2 = collect(N2)
155+
@test hV2 * hV2' + hN2 * hN2' I
156+
157+
V2, C2 = @constinferred left_orth!(
158+
copy!(Ac, A), (V, C); kind = kind,
159+
trunc = (; rtol = rtol)
160+
)
161+
AMDGPU.@allowscalar begin
162+
N2 = @constinferred left_null!(
163+
copy!(Ac, A), N; kind = kind,
164+
trunc = (; rtol = rtol)
165+
)
166+
end
167+
@test V2 !== V
168+
@test C2 !== C
169+
@test N2 !== C
170+
@test V2 * C2 A
171+
@test isisometric(V2)
172+
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
173+
@test isisometric(N2)
174+
hV2 = collect(V2)
175+
hN2 = collect(N2)
176+
@test hV2 * hV2' + hN2 * hN2' I
177+
else
178+
@test_throws ArgumentError left_orth!(
179+
copy!(Ac, A), (V, C); kind = kind,
180+
trunc = (; atol = atol)
181+
)
182+
@test_throws ArgumentError left_orth!(
183+
copy!(Ac, A), (V, C); kind = kind,
184+
trunc = (; rtol = rtol)
185+
)
186+
@test_throws ArgumentError left_null!(
187+
copy!(Ac, A), N; kind = kind,
188+
trunc = (; atol = atol)
189+
)
190+
@test_throws ArgumentError left_null!(
191+
copy!(Ac, A), N; kind = kind,
192+
trunc = (; rtol = rtol)
193+
)
194+
end
195+
end
196+
end
197+
end
198+
199+
@testset "right_orth and right_null for T = $T" for T in (
200+
Float32, Float64, ComplexF32,
201+
ComplexF64,
202+
)
203+
rng = StableRNG(123)
204+
m = 54
205+
@testset for n in (37, m, 63)
206+
minmn = min(m, n)
207+
A = ROCArray(randn(rng, T, m, n))
208+
C, Vᴴ = @constinferred right_orth(A)
209+
Nᴴ = @constinferred right_null(A)
210+
@test C isa ROCMatrix{T} && size(C) == (m, minmn)
211+
@test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (minmn, n)
212+
@test Nᴴ isa ROCMatrix{T} && size(Nᴴ) == (n - minmn, n)
213+
@test C * Vᴴ A
214+
@test isisometric(Vᴴ; side = :right)
215+
@test LinearAlgebra.norm(A * adjoint(Nᴴ)) 0 atol = MatrixAlgebraKit.defaulttol(T)
216+
@test isisometric(Nᴴ; side = :right)
217+
hVᴴ = collect(Vᴴ)
218+
hNᴴ = collect(Nᴴ)
219+
@test hVᴴ' * hVᴴ + hNᴴ' * hNᴴ I
220+
221+
M = LinearMap(A)
222+
CM, VMᴴ = @constinferred right_orth(M; kind = :svd)
223+
@test parent(CM) * parent(VMᴴ) A
224+
225+
Ac = similar(A)
226+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ))
227+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ)
228+
@test C2 === C
229+
@test Vᴴ2 === Vᴴ
230+
@test Nᴴ2 === Nᴴ
231+
@test C2 * Vᴴ2 A
232+
@test isisometric(Vᴴ2; side = :right)
233+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
234+
@test isisometric(Nᴴ; side = :right)
235+
hVᴴ2 = collect(Vᴴ2)
236+
hNᴴ2 = collect(Nᴴ2)
237+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
238+
239+
atol = eps(real(T))
240+
rtol = eps(real(T))
241+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol = atol))
242+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol = atol))
243+
@test C2 !== C
244+
@test Vᴴ2 !== Vᴴ
245+
@test Nᴴ2 !== Nᴴ
246+
@test C2 * Vᴴ2 A
247+
@test isisometric(Vᴴ2; side = :right)
248+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
249+
@test isisometric(Nᴴ; side = :right)
250+
hVᴴ2 = collect(Vᴴ2)
251+
hNᴴ2 = collect(Nᴴ2)
252+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
253+
254+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol = rtol))
255+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol = rtol))
256+
@test C2 !== C
257+
@test Vᴴ2 !== Vᴴ
258+
@test Nᴴ2 !== Nᴴ
259+
@test C2 * Vᴴ2 A
260+
@test isisometric(Vᴴ2; side = :right)
261+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
262+
@test isisometric(Nᴴ2; side = :right)
263+
hVᴴ2 = collect(Vᴴ2)
264+
hNᴴ2 = collect(Nᴴ2)
265+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
266+
267+
@testset "kind = $kind" for kind in (:lq, :polar, :svd)
268+
n < m && kind == :polar && continue
269+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind = kind)
270+
@test C2 === C
271+
@test Vᴴ2 === Vᴴ
272+
@test C2 * Vᴴ2 A
273+
@test isisometric(Vᴴ2; side = :right)
274+
if kind != :polar
275+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind = kind)
276+
@test Nᴴ2 === Nᴴ
277+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
278+
@test isisometric(Nᴴ2; side = :right)
279+
hVᴴ2 = collect(Vᴴ2)
280+
hNᴴ2 = collect(Nᴴ2)
281+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
282+
end
283+
284+
if kind == :svd
285+
C2, Vᴴ2 = @constinferred right_orth!(
286+
copy!(Ac, A), (C, Vᴴ); kind = kind,
287+
trunc = (; atol = atol)
288+
)
289+
Nᴴ2 = @constinferred right_null!(
290+
copy!(Ac, A), Nᴴ; kind = kind,
291+
trunc = (; atol = atol)
292+
)
293+
@test C2 !== C
294+
@test Vᴴ2 !== Vᴴ
295+
@test Nᴴ2 !== Nᴴ
296+
@test C2 * Vᴴ2 A
297+
@test isisometric(Vᴴ2; side = :right)
298+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
299+
@test isisometric(Nᴴ2; side = :right)
300+
hVᴴ2 = collect(Vᴴ2)
301+
hNᴴ2 = collect(Nᴴ2)
302+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
303+
304+
C2, Vᴴ2 = @constinferred right_orth!(
305+
copy!(Ac, A), (C, Vᴴ); kind = kind,
306+
trunc = (; rtol = rtol)
307+
)
308+
Nᴴ2 = @constinferred right_null!(
309+
copy!(Ac, A), Nᴴ; kind = kind,
310+
trunc = (; rtol = rtol)
311+
)
312+
@test C2 !== C
313+
@test Vᴴ2 !== Vᴴ
314+
@test Nᴴ2 !== Nᴴ
315+
@test C2 * Vᴴ2 A
316+
@test isisometric(Vᴴ2; side = :right)
317+
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
318+
@test isisometric(Nᴴ2; side = :right)
319+
hVᴴ2 = collect(Vᴴ2)
320+
hNᴴ2 = collect(Nᴴ2)
321+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 diagm(ones(T, size(Vᴴ2, 2))) atol = m * n * MatrixAlgebraKit.defaulttol(T)
322+
else
323+
@test_throws ArgumentError right_orth!(
324+
copy!(Ac, A), (C, Vᴴ); kind = kind,
325+
trunc = (; atol = atol)
326+
)
327+
@test_throws ArgumentError right_orth!(
328+
copy!(Ac, A), (C, Vᴴ); kind = kind,
329+
trunc = (; rtol = rtol)
330+
)
331+
@test_throws ArgumentError right_null!(
332+
copy!(Ac, A), Nᴴ; kind = kind,
333+
trunc = (; atol = atol)
334+
)
335+
@test_throws ArgumentError right_null!(
336+
copy!(Ac, A), Nᴴ; kind = kind,
337+
trunc = (; rtol = rtol)
338+
)
339+
end
340+
end
341+
end
342+
end

0 commit comments

Comments
 (0)