Skip to content

Commit e8ca7f1

Browse files
committed
Incremental progress
1 parent c2ecdf5 commit e8ca7f1

File tree

13 files changed, with 136 additions and 101 deletions.

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: LQViaTransposedQR
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
9-
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
9+
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!
1010
using CUDA
1111
using LinearAlgebra
1212
using LinearAlgebra: BlasFloat
@@ -30,6 +30,7 @@ _gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) = YACUSOLVER.ungqr!(A, τ)
3030
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix, τ::StridedCuVector, C::StridedCuVecOrMat) = YACUSOLVER.unmqr!(side, trans, A, τ, C)
3131
_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) = YACUSOLVER.gesvd!(A, S, U, Vᴴ)
3232
_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
33+
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
3334
_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
3435

3536
end

src/implementations/eig.jl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ function copy_input(::typeof(eig_vals), A::AbstractMatrix)
88
end
99
copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A)
1010

11-
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV)
11+
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
1212
m, n = size(A)
1313
m == n || throw(DimensionMismatch("square input matrix expected"))
1414
D, V = DV
@@ -19,7 +19,7 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV)
1919
@check_scalar(V, A, complex)
2020
return nothing
2121
end
22-
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D)
22+
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
2323
m, n = size(A)
2424
m == n || throw(DimensionMismatch("square input matrix expected"))
2525
@assert D isa AbstractVector
@@ -51,7 +51,7 @@ end
5151
# --------------
5252
# actual implementation
5353
function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
54-
check_input(eig_full!, A, DV)
54+
check_input(eig_full!, A, DV, alg)
5555
D, V = DV
5656
if alg isa LAPACK_Simple
5757
isempty(alg.kwargs) ||
@@ -66,7 +66,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
6666
end
6767

6868
function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm)
69-
check_input(eig_vals!, A, D)
69+
check_input(eig_vals!, A, D, alg)
7070
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
7171
if alg isa LAPACK_Simple
7272
isempty(alg.kwargs) ||

src/implementations/eigh.jl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ function copy_input(::typeof(eigh_vals), A::AbstractMatrix)
88
end
99
copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)
1010

11-
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV)
11+
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
1212
m, n = size(A)
1313
m == n || throw(DimensionMismatch("square input matrix expected"))
1414
D, V = DV
@@ -19,7 +19,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV)
1919
@check_scalar(V, A)
2020
return nothing
2121
end
22-
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D)
22+
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
2323
m, n = size(A)
2424
@assert D isa AbstractVector
2525
@check_size(D, (n,))
@@ -48,7 +48,7 @@ end
4848
# Implementation
4949
# --------------
5050
function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
51-
check_input(eigh_full!, A, DV)
51+
check_input(eigh_full!, A, DV, alg)
5252
D, V = DV
5353
Dd = D.diag
5454
if alg isa LAPACK_MultipleRelativelyRobustRepresentations
@@ -70,7 +70,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
7070
end
7171

7272
function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm)
73-
check_input(eigh_vals!, A, D)
73+
check_input(eigh_vals!, A, D, alg)
7474
V = similar(A, (size(A, 1), 0))
7575
if alg isa LAPACK_MultipleRelativelyRobustRepresentations
7676
YALAPACK.heevr!(A, D, V; alg.kwargs...)

src/implementations/gen_eig.jl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ function copy_input(::typeof(gen_eig_vals), A::AbstractMatrix, B::AbstractMatrix
77
return copy_input(gen_eig_full, A, B)
88
end
99

10-
function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV)
10+
function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm)
1111
ma, na = size(A)
1212
mb, nb = size(B)
1313
ma == na || throw(DimensionMismatch("square input matrix A expected"))
@@ -24,7 +24,7 @@ function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatr
2424
@check_scalar(V, B, complex)
2525
return nothing
2626
end
27-
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W)
27+
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W, ::AbstractAlgorithm)
2828
ma, na = size(A)
2929
mb, nb = size(B)
3030
ma == na || throw(DimensionMismatch("square input matrix A expected"))
@@ -57,7 +57,7 @@ end
5757
# --------------
5858
# actual implementation
5959
function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_EigAlgorithm)
60-
check_input(gen_eig_full!, A, B, WV)
60+
check_input(gen_eig_full!, A, B, WV, alg)
6161
W, V = WV
6262
if alg isa LAPACK_Simple
6363
isempty(alg.kwargs) ||
@@ -72,7 +72,7 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig
7272
end
7373

7474
function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigAlgorithm)
75-
check_input(gen_eig_vals!, A, B, W)
75+
check_input(gen_eig_vals!, A, B, W, alg)
7676
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
7777
if alg isa LAPACK_Simple
7878
isempty(alg.kwargs) ||

src/implementations/lq.jl

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ function copy_input(::typeof(lq_null), A::AbstractMatrix)
1010
return copy!(similar(A, float(eltype(A))), A)
1111
end
1212

13-
function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ)
13+
function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
1414
m, n = size(A)
1515
L, Q = LQ
1616
@assert L isa AbstractMatrix && Q isa AbstractMatrix
@@ -20,7 +20,7 @@ function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ)
2020
@check_scalar(Q, A)
2121
return nothing
2222
end
23-
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ)
23+
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
2424
m, n = size(A)
2525
minmn = min(m, n)
2626
L, Q = LQ
@@ -31,7 +31,7 @@ function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ)
3131
@check_scalar(Q, A)
3232
return nothing
3333
end
34-
function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ)
34+
function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgorithm)
3535
m, n = size(A)
3636
minmn = min(m, n)
3737
@assert Nᴴ isa AbstractMatrix
@@ -66,36 +66,36 @@ end
6666
# --------------
6767
# actual implementation
6868
function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
69-
check_input(lq_full!, A, LQ)
69+
check_input(lq_full!, A, LQ, alg)
7070
L, Q = LQ
7171
_lapack_lq!(A, L, Q; alg.kwargs...)
7272
return L, Q
7373
end
7474
function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
75-
check_input(lq_full!, A, LQ)
75+
check_input(lq_full!, A, LQ, alg)
7676
L, Q = LQ
7777
lq_via_qr!(A, L, Q, alg.qr_alg)
7878
return L, Q
7979
end
8080
function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
81-
check_input(lq_compact!, A, LQ)
81+
check_input(lq_compact!, A, LQ, alg)
8282
L, Q = LQ
8383
_lapack_lq!(A, L, Q; alg.kwargs...)
8484
return L, Q
8585
end
8686
function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
87-
check_input(lq_compact!, A, LQ)
87+
check_input(lq_compact!, A, LQ, alg)
8888
L, Q = LQ
8989
lq_via_qr!(A, L, Q, alg.qr_alg)
9090
return L, Q
9191
end
9292
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ)
93-
check_input(lq_null!, A, Nᴴ)
93+
check_input(lq_null!, A, Nᴴ, alg)
9494
_lapack_lq_null!(A, Nᴴ; alg.kwargs...)
9595
return Nᴴ
9696
end
9797
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR)
98-
check_input(lq_null!, A, Nᴴ)
98+
check_input(lq_null!, A, Nᴴ, alg)
9999
lq_null_via_qr!(A, Nᴴ, alg.qr_alg)
100100
return Nᴴ
101101
end

src/implementations/orthnull.jl

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ copy_input(::typeof(right_orth), A) = copy_input(lq_compact, A) # do we ever nee
55
copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need anything else
66
copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need anything else
77

8-
function check_input(::typeof(left_orth!), A::AbstractMatrix, VC)
8+
function check_input(::typeof(left_orth!), A::AbstractMatrix, VC, ::AbstractAlgorithm)
99
m, n = size(A)
1010
minmn = min(m, n)
1111
V, C = VC
@@ -18,7 +18,7 @@ function check_input(::typeof(left_orth!), A::AbstractMatrix, VC)
1818
end
1919
return nothing
2020
end
21-
function check_input(::typeof(right_orth!), A::AbstractMatrix, CVᴴ)
21+
function check_input(::typeof(right_orth!), A::AbstractMatrix, CVᴴ, ::AbstractAlgorithm)
2222
m, n = size(A)
2323
minmn = min(m, n)
2424
C, Vᴴ = CVᴴ
@@ -32,15 +32,15 @@ function check_input(::typeof(right_orth!), A::AbstractMatrix, CVᴴ)
3232
return nothing
3333
end
3434

35-
function check_input(::typeof(left_null!), A::AbstractMatrix, N)
35+
function check_input(::typeof(left_null!), A::AbstractMatrix, N, ::AbstractAlgorithm)
3636
m, n = size(A)
3737
minmn = min(m, n)
3838
@assert N isa AbstractMatrix
3939
@check_size(N, (m, m - minmn))
4040
@check_scalar(N, A)
4141
return nothing
4242
end
43-
function check_input(::typeof(right_null!), A::AbstractMatrix, Nᴴ)
43+
function check_input(::typeof(right_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgorithm)
4444
m, n = size(A)
4545
minmn = min(m, n)
4646
@assert Nᴴ isa AbstractMatrix
@@ -84,7 +84,6 @@ end
8484
function left_orth!(A, VC; trunc=nothing,
8585
kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true),
8686
alg_polar=(;), alg_svd=(;))
87-
check_input(left_orth!, A, VC)
8887
if !isnothing(trunc) && kind != :svd
8988
throw(ArgumentError("truncation not supported for left_orth with kind=$kind"))
9089
end
@@ -100,34 +99,40 @@ function left_orth!(A, VC; trunc=nothing,
10099
end
101100
function left_orth_qr!(A, VC, alg)
102101
alg′ = select_algorithm(qr_compact!, A, alg)
102+
check_input(left_orth!, A, VC, alg′)
103103
return qr_compact!(A, VC, alg′)
104104
end
105105
function left_orth_polar!(A, VC, alg)
106106
alg′ = select_algorithm(left_polar!, A, alg)
107+
check_input(left_orth!, A, VC, alg′)
107108
return left_polar!(A, VC, alg′)
108109
end
109110
function left_orth_svd!(A, VC, alg, trunc::Nothing=nothing)
110111
alg′ = select_algorithm(svd_compact!, A, alg)
112+
check_input(left_orth!, A, VC, alg′)
111113
U, S, Vᴴ = svd_compact!(A, alg′)
112114
V, C = VC
113115
return copy!(V, U), mul!(C, S, Vᴴ)
114116
end
115117
function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc::Nothing=nothing)
116118
alg′ = select_algorithm(svd_compact!, A, alg)
119+
check_input(left_orth!, A, VC, alg′)
117120
V, C = VC
118121
S = Diagonal(initialize_output(svd_vals!, A, alg′))
119122
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg′)
120123
return U, lmul!(S, Vᴴ)
121124
end
122125
function left_orth_svd!(A, VC, alg, trunc)
123126
alg′ = select_algorithm(svd_compact!, A, alg)
127+
check_input(left_orth!, A, VC, alg′)
124128
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
125129
U, S, Vᴴ = svd_trunc!(A, alg_trunc)
126130
V, C = VC
127131
return copy!(V, U), mul!(C, S, Vᴴ)
128132
end
129133
function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc)
130134
alg′ = select_algorithm(svd_compact!, A, alg)
135+
check_input(left_orth!, A, VC, alg′)
131136
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
132137
V, C = VC
133138
S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg))
@@ -138,7 +143,6 @@ end
138143
function right_orth!(A, CVᴴ; trunc=nothing,
139144
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
140145
alg_polar=(;), alg_svd=(;))
141-
check_input(right_orth!, A, CVᴴ)
142146
if !isnothing(trunc) && kind != :svd
143147
throw(ArgumentError("truncation not supported for right_orth with kind=$kind"))
144148
end
@@ -154,34 +158,40 @@ function right_orth!(A, CVᴴ; trunc=nothing,
154158
end
155159
function right_orth_lq!(A, CVᴴ, alg)
156160
alg′ = select_algorithm(lq_compact!, A, alg)
161+
check_input(right_orth!, A, CVᴴ, alg′)
157162
return lq_compact!(A, CVᴴ, alg′)
158163
end
159164
function right_orth_polar!(A, CVᴴ, alg)
160165
alg′ = select_algorithm(right_polar!, A, alg)
166+
check_input(right_orth!, A, CVᴴ, alg′)
161167
return right_polar!(A, CVᴴ, alg′)
162168
end
163169
function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing=nothing)
164170
alg′ = select_algorithm(svd_compact!, A, alg)
171+
check_input(right_orth!, A, CVᴴ, alg′)
165172
U, S, Vᴴ′ = svd_compact!(A, alg′)
166173
C, Vᴴ = CVᴴ
167174
return mul!(C, U, S), copy!(Vᴴ, Vᴴ′)
168175
end
169176
function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc::Nothing=nothing)
170177
alg′ = select_algorithm(svd_compact!, A, alg)
178+
check_input(right_orth!, A, CVᴴ, alg′)
171179
C, Vᴴ = CVᴴ
172180
S = Diagonal(initialize_output(svd_vals!, A, alg′))
173181
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg′)
174182
return rmul!(U, S), Vᴴ
175183
end
176184
function right_orth_svd!(A, CVᴴ, alg, trunc)
177185
alg′ = select_algorithm(svd_compact!, A, alg)
186+
check_input(right_orth!, A, CVᴴ, alg′)
178187
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
179188
U, S, Vᴴ′ = svd_trunc!(A, alg_trunc)
180189
C, Vᴴ = CVᴴ
181190
return mul!(C, U, S), copy!(Vᴴ, Vᴴ′)
182191
end
183192
function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc)
184193
alg′ = select_algorithm(svd_compact!, A, alg)
194+
check_input(right_orth!, A, CVᴴ, alg′)
185195
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
186196
C, Vᴴ = CVᴴ
187197
S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg))
@@ -204,7 +214,6 @@ end
204214
function left_null!(A, N; trunc=nothing,
205215
kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true),
206216
alg_svd=(;))
207-
check_input(left_null!, A, N)
208217
if !isnothing(trunc) && kind != :svd
209218
throw(ArgumentError("truncation not supported for left_null with kind=$kind"))
210219
end
@@ -218,10 +227,12 @@ function left_null!(A, N; trunc=nothing,
218227
end
219228
function left_null_qr!(A, N, alg)
220229
alg′ = select_algorithm(qr_null!, A, alg)
230+
check_input(left_null!, A, N, alg′)
221231
return qr_null!(A, N, alg′)
222232
end
223233
function left_null_svd!(A, N, alg, trunc::Nothing=nothing)
224234
alg′ = select_algorithm(svd_full!, A, alg)
235+
check_input(left_null!, A, N, alg′)
225236
U, _, _ = svd_full!(A, alg′)
226237
(m, n) = size(A)
227238
return copy!(N, view(U, 1:m, (n + 1):m))
@@ -238,7 +249,6 @@ end
238249
function right_null!(A, Nᴴ; trunc=nothing,
239250
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
240251
alg_svd=(;))
241-
check_input(right_null!, A, Nᴴ)
242252
if !isnothing(trunc) && kind != :svd
243253
throw(ArgumentError("truncation not supported for right_null with kind=$kind"))
244254
end
@@ -252,16 +262,19 @@ function right_null!(A, Nᴴ; trunc=nothing,
252262
end
253263
function right_null_lq!(A, Nᴴ, alg)
254264
alg′ = select_algorithm(lq_null!, A, alg)
265+
check_input(right_null!, A, Nᴴ, alg′)
255266
return lq_null!(A, Nᴴ, alg′)
256267
end
257268
function right_null_svd!(A, Nᴴ, alg, trunc::Nothing=nothing)
258269
alg′ = select_algorithm(svd_full!, A, alg)
270+
check_input(right_null!, A, Nᴴ, alg′)
259271
_, _, Vᴴ = svd_full!(A, alg′)
260272
(m, n) = size(A)
261273
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
262274
end
263275
function right_null_svd!(A, Nᴴ, alg, trunc)
264276
alg′ = select_algorithm(svd_full!, A, alg)
277+
check_input(right_null!, A, Nᴴ, alg′)
265278
_, S, Vᴴ = svd_full!(A, alg′)
266279
trunc′ = trunc isa TruncationStrategy ? trunc :
267280
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :

0 commit comments

Comments
 (0)