Skip to content

Commit e46b7e5

Browse files
committed
Fixup amd tests
1 parent bb2feb0 commit e46b7e5

File tree

1 file changed

+92
-69
lines changed

1 file changed

+92
-69
lines changed

test/amd/orthnull.jl

Lines changed: 92 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using StableRNGs
5-
using LinearAlgebra: LinearAlgebra, I, mul!, norm
6-
using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
5+
using LinearAlgebra: LinearAlgebra, I, mul!, diagm, norm
76
using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
87
initialize_output, AbstractAlgorithm
98
using AMDGPU
@@ -56,7 +55,7 @@ end
5655
@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
5756
rng = StableRNG(123)
5857
m = 54
59-
for n in (37, m, 63)
58+
@testset for n in (37, m, 63)
6059
minmn = min(m, n)
6160
A = ROCArray(randn(rng, T, m, n))
6261
V, C = @constinferred left_orth(A)
@@ -66,7 +65,7 @@ end
6665
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
6766
@test V * C A
6867
@test isisometric(V)
69-
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
68+
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
7069
@test isisometric(N)
7170
hV = collect(V)
7271
hN = collect(N)
@@ -79,15 +78,16 @@ end
7978
if m > n
8079
nullity = 5
8180
V, C = @constinferred left_orth(A)
82-
# doesn't work because of truncation
83-
#N = @constinferred left_null(A; trunc=(; maxnullity=nullity))
81+
AMDGPU.@allowscalar begin
82+
N = @constinferred left_null(A; trunc = (; maxnullity = nullity))
83+
end
8484
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
8585
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
86-
#@test N isa ROCMatrix{T} && size(N) == (m, nullity)
86+
@test N isa ROCMatrix{T} && size(N) == (m, nullity)
8787
@test V * C ≈ A
8888
@test isisometric(V)
89-
#@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
90-
#@test isisometric(N)
89+
@test LinearAlgebra.norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
90+
@test isisometric(N)
9191
end
9292

9393
for alg_qr in ((; positive = true), (; positive = false), ROCSOLVER_HouseholderQR())
@@ -98,7 +98,7 @@ end
9898
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
9999
@test V * C A
100100
@test isisometric(V)
101-
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
101+
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
102102
@test isisometric(N)
103103
hV = collect(V)
104104
hN = collect(N)
@@ -120,35 +120,39 @@ end
120120
@test hV2 * hV2' + hN2 * hN2' I
121121

122122
atol = eps(real(T))
123-
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol))
123+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol))
124124
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol))
125-
#@test V2 !== V
126-
#@test C2 !== C
125+
@test V2 !== V
126+
@test C2 !== C
127127
@test N2 !== C
128-
#@test V2 * C2 ≈ A
129-
#@test isisometric(V2)
128+
@test V2 * C2 A
129+
@test isisometric(V2)
130130
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
131131
@test isisometric(N2)
132-
#@test V2 * V2' + N2 * N2' ≈ I
132+
hV2 = collect(V2)
133+
hN2 = collect(N2)
134+
@test hV2 * hV2' + hN2 * hN2' ≈ I
133135
134136
rtol = eps(real(T))
135137
for (trunc_orth, trunc_null) in (
136138
((; rtol = rtol), (; rtol = rtol)),
137-
(TruncationKeepAbove(0, rtol), TruncationKeepBelow(0, rtol)),
139+
(trunctol(; rtol), trunctol(; rtol, keep_below = true)),
138140
)
139-
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth)
141+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth)
140142
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null)
141-
#@test V2 !== V
142-
#@test C2 !== C
143+
@test V2 !== V
144+
@test C2 !== C
143145
@test N2 !== C
144-
#@test V2 * C2 ≈ A
145-
#@test isisometric(V2)
146+
@test V2 * C2 ≈ A
147+
@test isisometric(V2)
146148
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
147149
@test isisometric(N2)
148-
#@test V2 * V2' + N2 * N2' ≈ I
150+
hV2 = collect(V2)
151+
hN2 = collect(N2)
152+
@test hV2 * hV2' + hN2 * hN2' I
149153
end
150154

151-
for kind in (:qr, :polar, :svd) # explicit kind kwarg
155+
@testset for kind in (:qr, :polar, :svd) # explicit kind kwarg
152156
m < n && kind == :polar && continue
153157
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind = kind)
154158
@test V2 === V
@@ -167,35 +171,43 @@ end
167171
168172
# with kind and tol kwargs
169173
if kind == :svd
170-
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
171-
# trunc=(; atol=atol))
174+
V2, C2 = @constinferred left_orth!(
175+
copy!(Ac, A), (V, C); kind = kind,
176+
trunc = (; atol = atol)
177+
)
172178
N2 = @constinferred left_null!(
173179
copy!(Ac, A), N; kind = kind,
174180
trunc = (; atol = atol)
175181
)
176-
#@test V2 !== V
177-
#@test C2 !== C
182+
@test V2 !== V
183+
@test C2 !== C
178184
@test N2 !== C
179-
#@test V2 * C2 ≈ A
180-
#@test V2' * V2 I
185+
@test V2 * C2 ≈ A
186+
@test V2' * V2 I
181187
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
182188
@test isisometric(N2)
183-
#@test V2 * V2' + N2 * N2' ≈ I
189+
hV2 = collect(V2)
190+
hN2 = collect(N2)
191+
@test hV2 * hV2' + hN2 * hN2' ≈ I
184192
185-
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
186-
# trunc=(; rtol=rtol))
193+
V2, C2 = @constinferred left_orth!(
194+
copy!(Ac, A), (V, C); kind = kind,
195+
trunc = (; rtol = rtol)
196+
)
187197
N2 = @constinferred left_null!(
188198
copy!(Ac, A), N; kind = kind,
189199
trunc = (; rtol = rtol)
190200
)
191-
#@test V2 !== V
192-
#@test C2 !== C
201+
@test V2 !== V
202+
@test C2 !== C
193203
@test N2 !== C
194-
#@test V2 * C2 ≈ A
195-
#@test isisometric(V2)
204+
@test V2 * C2 ≈ A
205+
@test isisometric(V2)
196206
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
197207
@test isisometric(N2)
198-
#@test V2 * V2' + N2 * N2' ≈ I
208+
hV2 = collect(V2)
209+
hN2 = collect(N2)
210+
@test hV2 * hV2' + hN2 * hN2' I
199211
else
200212
@test_throws ArgumentError left_orth!(
201213
copy!(Ac, A), (V, C); kind = kind,
@@ -258,39 +270,40 @@ end
258270
hNᴴ2 = collect(Nᴴ2)
259271
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
260272

261-
# TODO truncate currently broken due to searchsortedlast
262273
atol = eps(real(T))
263274
rtol = eps(real(T))
264-
#=C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol))
265-
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; atol=atol))
275+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol = atol))
276+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol = atol))
266277
@test C2 !== C
267278
@test Vᴴ2 !== Vᴴ
268279
@test Nᴴ2 !== Nᴴ
269280
@test C2 * Vᴴ2 A
270-
@test isisometric(Vᴴ2; side=:right)
281+
@test isisometric(Vᴴ2; side = :right)
271282
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
272-
@test isisometric(Nᴴ; side=:right)
273-
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
283+
@test isisometric(Nᴴ; side = :right)
284+
hVᴴ2 = collect(Vᴴ2)
285+
hNᴴ2 = collect(Nᴴ2)
286+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
274287

275-
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; rtol=rtol))
276-
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; rtol=rtol))
288+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol = rtol))
289+
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol = rtol))
277290
@test C2 !== C
278291
@test Vᴴ2 !== Vᴴ
279292
@test Nᴴ2 !== Nᴴ
280293
@test C2 * Vᴴ2 A
281-
@test isisometric(Vᴴ2; side=:right)
294+
@test isisometric(Vᴴ2; side = :right)
282295
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
283-
@test isisometric(Nᴴ2; side=:right)
284-
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
285-
=#
296+
@test isisometric(Nᴴ2; side = :right)
297+
hVᴴ2 = collect(Vᴴ2)
298+
hNᴴ2 = collect(Nᴴ2)
299+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
286300

287301
@testset "kind = $kind" for kind in (:lq, :polar, :svd)
288302
n < m && kind == :polar && continue
289303
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind = kind)
290304
@test C2 === C
291305
@test Vᴴ2 === Vᴴ
292-
A2 = C2 * Vᴴ2
293-
@test A2 A
306+
@test C2 * Vᴴ2 A
294307
@test isisometric(Vᴴ2; side = :right)
295308
if kind != :polar
296309
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind = kind)
@@ -303,33 +316,43 @@ end
303316
end
304317

305318
if kind == :svd
306-
# doesn't work yet because of searchsortedfirst
307-
#= C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
308-
trunc=(; atol=atol))
309-
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind,
310-
trunc=(; atol=atol))
319+
C2, Vᴴ2 = @constinferred right_orth!(
320+
copy!(Ac, A), (C, Vᴴ); kind = kind,
321+
trunc = (; atol = atol)
322+
)
323+
Nᴴ2 = @constinferred right_null!(
324+
copy!(Ac, A), Nᴴ; kind = kind,
325+
trunc = (; atol = atol)
326+
)
311327
@test C2 !== C
312328
@test Vᴴ2 !== Vᴴ
313329
@test Nᴴ2 !== Nᴴ
314330
@test C2 * Vᴴ2 A
315-
@test isisometric(Vᴴ2; side=:right)
331+
@test isisometric(Vᴴ2; side = :right)
316332
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
317-
@test isisometric(Nᴴ2; side=:right)
318-
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
319-
320-
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
321-
trunc=(; rtol=rtol))
322-
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind,
323-
trunc=(; rtol=rtol))
333+
@test isisometric(Nᴴ2; side = :right)
334+
hVᴴ2 = collect(Vᴴ2)
335+
hNᴴ2 = collect(Nᴴ2)
336+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
337+
338+
C2, Vᴴ2 = @constinferred right_orth!(
339+
copy!(Ac, A), (C, Vᴴ); kind = kind,
340+
trunc = (; rtol = rtol)
341+
)
342+
Nᴴ2 = @constinferred right_null!(
343+
copy!(Ac, A), Nᴴ; kind = kind,
344+
trunc = (; rtol = rtol)
345+
)
324346
@test C2 !== C
325347
@test Vᴴ2 !== Vᴴ
326348
@test Nᴴ2 !== Nᴴ
327349
@test C2 * Vᴴ2 A
328-
@test isisometric(Vᴴ2; side=:right)
350+
@test isisometric(Vᴴ2; side = :right)
329351
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
330-
@test isisometric(Nᴴ2; side=:right)
331-
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
332-
=#
352+
@test isisometric(Nᴴ2; side = :right)
353+
hVᴴ2 = collect(Vᴴ2)
354+
hNᴴ2 = collect(Nᴴ2)
355+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 diagm(ones(T, size(Vᴴ2, 2))) atol = m * n * MatrixAlgebraKit.defaulttol(T)
333356
else
334357
@test_throws ArgumentError right_orth!(
335358
copy!(Ac, A), (C, Vᴴ); kind = kind,

0 commit comments

Comments
 (0)