Skip to content

Commit 64954ef

Browse files
author
Katharine Hyatt
committed
Updates for AMD
1 parent 3ea43a3 commit 64954ef

File tree

3 files changed

+75
-58
lines changed

3 files changed

+75
-58
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Aqua = "0.6, 0.7, 0.8"
2222
ChainRulesCore = "1"
2323
ChainRulesTestUtils = "1"
2424
CUDA = "5"
25-
JET = "0.9"
25+
JET = "0.10, 0.9"
2626
LinearAlgebra = "1"
2727
SafeTestsets = "0.1"
2828
StableRNGs = "1"

test/amd/orthnull.jl

Lines changed: 65 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using StableRNGs
5-
using LinearAlgebra: LinearAlgebra, I, mul!, norm
6-
using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
5+
using LinearAlgebra: LinearAlgebra, I, mul!, diagm, norm
76
using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
87
initialize_output, AbstractAlgorithm
9-
using AMDGPU
8+
using AMDGPU
109

1110
# Used to test non-AbstractMatrix codepaths.
1211
struct LinearMap{P<:AbstractMatrix}
@@ -54,7 +53,7 @@ end
5453
@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
5554
rng = StableRNG(123)
5655
m = 54
57-
for n in (37, m, 63)
56+
@testset for n in (37, m, 63)
5857
minmn = min(m, n)
5958
A = ROCArray(randn(rng, T, m, n))
6059
V, C = @constinferred left_orth(A)
@@ -64,7 +63,7 @@ end
6463
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
6564
@test V * C A
6665
@test isisometric(V)
67-
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
66+
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
6867
@test isisometric(N)
6968
hV = collect(V)
7069
hN = collect(N)
@@ -77,15 +76,16 @@ end
7776
if m > n
7877
nullity = 5
7978
V, C = @constinferred left_orth(A)
80-
# doesn't work because of truncation
81-
#N = @constinferred left_null(A; trunc=(; maxnullity=nullity))
79+
AMDGPU.@allowscalar begin
80+
N = @constinferred left_null(A; trunc=(; maxnullity=nullity))
81+
end
8282
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
8383
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
84-
#@test N isa ROCMatrix{T} && size(N) == (m, nullity)
84+
@test N isa ROCMatrix{T} && size(N) == (m, nullity)
8585
@test V * C ≈ A
8686
@test isisometric(V)
87-
#@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
88-
#@test isisometric(N)
87+
@test LinearAlgebra.norm(A' * N) 0 atol = MatrixAlgebraKit.defaulttol(T)
88+
@test isisometric(N)
8989
end
9090

9191
for alg_qr in ((; positive=true), (; positive=false), ROCSOLVER_HouseholderQR())
@@ -96,7 +96,7 @@ end
9696
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
9797
@test V * C A
9898
@test isisometric(V)
99-
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
99+
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
100100
@test isisometric(N)
101101
hV = collect(V)
102102
hN = collect(N)
@@ -118,33 +118,39 @@ end
118118
@test hV2 * hV2' + hN2 * hN2' I
119119

120120
atol = eps(real(T))
121-
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol))
121+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol))
122122
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=(; atol=atol))
123-
#@test V2 !== V
124-
#@test C2 !== C
123+
@test V2 !== V
124+
@test C2 !== C
125125
@test N2 !== C
126-
#@test V2 * C2 ≈ A
127-
#@test isisometric(V2)
126+
@test V2 * C2 A
127+
@test isisometric(V2)
128128
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
129129
@test isisometric(N2)
130-
#@test V2 * V2' + N2 * N2' ≈ I
130+
hV2 = collect(V2)
131+
hN2 = collect(N2)
132+
@test hV2 * hV2' + hN2 * hN2' ≈ I
131133
132134
rtol = eps(real(T))
133-
for (trunc_orth, trunc_null) in (((; rtol=rtol), (; rtol=rtol)),
134-
(TruncationKeepAbove(0, rtol), TruncationKeepBelow(0, rtol)))
135-
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth)
135+
for (trunc_orth, trunc_null) in (
136+
((; rtol = rtol), (; rtol = rtol)),
137+
(trunctol(; rtol), trunctol(; rtol, keep_below = true)),
138+
)
139+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth)
136140
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=trunc_null)
137-
#@test V2 !== V
138-
#@test C2 !== C
141+
@test V2 !== V
142+
@test C2 !== C
139143
@test N2 !== C
140-
#@test V2 * C2 ≈ A
141-
#@test isisometric(V2)
144+
@test V2 * C2 ≈ A
145+
@test isisometric(V2)
142146
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
143147
@test isisometric(N2)
144-
#@test V2 * V2' + N2 * N2' ≈ I
148+
hV2 = collect(V2)
149+
hN2 = collect(N2)
150+
@test hV2 * hV2' + hN2 * hN2' I
145151
end
146152

147-
for kind in (:qr, :polar, :svd) # explicit kind kwarg
153+
@testset for kind in (:qr, :polar, :svd) # explicit kind kwarg
148154
m < n && kind == :polar && continue
149155
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind)
150156
@test V2 === V
@@ -163,31 +169,35 @@ end
163169
164170
# with kind and tol kwargs
165171
if kind == :svd
166-
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
167-
# trunc=(; atol=atol))
172+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
173+
trunc=(; atol=atol))
168174
N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind,
169175
trunc=(; atol=atol))
170-
#@test V2 !== V
171-
#@test C2 !== C
176+
@test V2 !== V
177+
@test C2 !== C
172178
@test N2 !== C
173-
#@test V2 * C2 ≈ A
174-
#@test V2' * V2 I
179+
@test V2 * C2 ≈ A
180+
@test V2' * V2 I
175181
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
176182
@test isisometric(N2)
177-
#@test V2 * V2' + N2 * N2' ≈ I
183+
hV2 = collect(V2)
184+
hN2 = collect(N2)
185+
@test hV2 * hV2' + hN2 * hN2' ≈ I
178186
179-
#V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
180-
# trunc=(; rtol=rtol))
187+
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind,
188+
trunc=(; rtol=rtol))
181189
N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind,
182190
trunc=(; rtol=rtol))
183-
#@test V2 !== V
184-
#@test C2 !== C
191+
@test V2 !== V
192+
@test C2 !== C
185193
@test N2 !== C
186-
#@test V2 * C2 ≈ A
187-
#@test isisometric(V2)
194+
@test V2 * C2 ≈ A
195+
@test isisometric(V2)
188196
@test LinearAlgebra.norm(A' * N2) 0 atol = MatrixAlgebraKit.defaulttol(T)
189197
@test isisometric(N2)
190-
#@test V2 * V2' + N2 * N2' ≈ I
198+
hV2 = collect(V2)
199+
hN2 = collect(N2)
200+
@test hV2 * hV2' + hN2 * hN2' I
191201
else
192202
@test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind=kind,
193203
trunc=(; atol=atol))
@@ -240,10 +250,9 @@ end
240250
hNᴴ2 = collect(Nᴴ2)
241251
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
242252

243-
# TODO truncate currently broken due to searchsortedlast
244253
atol = eps(real(T))
245254
rtol = eps(real(T))
246-
#=C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol))
255+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol))
247256
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; atol=atol))
248257
@test C2 !== C
249258
@test Vᴴ2 !== Vᴴ
@@ -252,7 +261,9 @@ end
252261
@test isisometric(Vᴴ2; side=:right)
253262
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
254263
@test isisometric(Nᴴ; side=:right)
255-
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
264+
hVᴴ2 = collect(Vᴴ2)
265+
hNᴴ2 = collect(Nᴴ2)
266+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
256267

257268
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; rtol=rtol))
258269
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; rtol=rtol))
@@ -263,16 +274,16 @@ end
263274
@test isisometric(Vᴴ2; side=:right)
264275
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
265276
@test isisometric(Nᴴ2; side=:right)
266-
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
267-
=#
277+
hVᴴ2 = collect(Vᴴ2)
278+
hNᴴ2 = collect(Nᴴ2)
279+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
268280

269281
@testset "kind = $kind" for kind in (:lq, :polar, :svd)
270282
n < m && kind == :polar && continue
271283
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind)
272284
@test C2 === C
273285
@test Vᴴ2 === Vᴴ
274-
A2 = C2 * Vᴴ2
275-
@test A2 A
286+
@test C2 * Vᴴ2 A
276287
@test isisometric(Vᴴ2; side=:right)
277288
if kind != :polar
278289
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind)
@@ -285,8 +296,7 @@ end
285296
end
286297

287298
if kind == :svd
288-
# doesn't work yet because of searchsortedfirst
289-
#= C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
299+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
290300
trunc=(; atol=atol))
291301
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind,
292302
trunc=(; atol=atol))
@@ -297,7 +307,9 @@ end
297307
@test isisometric(Vᴴ2; side=:right)
298308
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
299309
@test isisometric(Nᴴ2; side=:right)
300-
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
310+
hVᴴ2 = collect(Vᴴ2)
311+
hNᴴ2 = collect(Nᴴ2)
312+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 I
301313

302314
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
303315
trunc=(; rtol=rtol))
@@ -310,8 +322,9 @@ end
310322
@test isisometric(Vᴴ2; side=:right)
311323
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) 0 atol = MatrixAlgebraKit.defaulttol(T)
312324
@test isisometric(Nᴴ2; side=:right)
313-
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I
314-
=#
325+
hVᴴ2 = collect(Vᴴ2)
326+
hNᴴ2 = collect(Nᴴ2)
327+
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 diagm(ones(T, size(Vᴴ2, 2))) atol = m*n*MatrixAlgebraKit.defaulttol(T)
315328
else
316329
@test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
317330
trunc=(; atol=atol))

test/amd/polar.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using StableRNGs
5-
using LinearAlgebra: LinearAlgebra, I, isposdef
5+
using LinearAlgebra: LinearAlgebra, I, isposdef, Hermitian
66
using MatrixAlgebraKit: PolarViaSVD
77
using AMDGPU
88

@@ -21,15 +21,17 @@ using AMDGPU
2121
@test P isa ROCMatrix{T} && size(P) == (n, n)
2222
@test W * P A
2323
@test isisometric(W)
24-
@test isposdef(P)
24+
# work around extremely strict Julia criteria for Hermiticity
25+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
2526

2627
Ac = similar(A)
2728
W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, P), alg)
2829
@test W2 === W
2930
@test P2 === P
3031
@test W * P A
3132
@test isisometric(W)
32-
@test isposdef(P)
33+
# work around extremely strict Julia criteria for Hermiticity
34+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
3335

3436
noP = similar(P, (0, 0))
3537
W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, noP), alg)
@@ -58,15 +60,17 @@ end
5860
@test P isa ROCMatrix{T} && size(P) == (m, m)
5961
@test P * Wᴴ ≈ A
6062
@test isisometric(Wᴴ; side=:right)
61-
@test isposdef(P)
63+
# work around extremely strict Julia criteria for Hermiticity
64+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
6265
6366
Ac = similar(A)
6467
P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (P, Wᴴ), alg)
6568
@test P2 === P
6669
@test Wᴴ2 === Wᴴ
6770
@test P * Wᴴ ≈ A
6871
@test isisometric(Wᴴ; side=:right)
69-
@test isposdef(P)
72+
# work around extremely strict Julia criteria for Hermiticity
73+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
7074
7175
noP = similar(P, (0, 0))
7276
P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg)

0 commit comments

Comments
 (0)