Skip to content

Commit 58846bb

Browse files
kshyattKatharine Hyatt
andauthored
Updates for TensorKit compatibility (#49)
* Add tests for image and null space for GPU * Only use the scalar method for AMDGPU * Updates for TensorKit compatibility * Fix AMDGPU duplication * Fix AMD polar test * Comments --------- Co-authored-by: Katharine Hyatt <[email protected]>
1 parent c99084b commit 58846bb

File tree

9 files changed

+35
-16
lines changed

9 files changed

+35
-16
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ Zygote = "0.7"
3838
julia = "1.10"
3939

4040
[extras]
41+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
4142
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4243
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
44+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4345
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4446
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4547
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,47 @@ using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
11-
using CUDA
11+
using CUDA, CUDA.CUBLAS
1212
using CUDA: i32
1313
using LinearAlgebra
1414
using LinearAlgebra: BlasFloat
1515

1616
include("yacusolver.jl")
1717

18-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
18+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
1919
return CUSOLVER_HouseholderQR(; kwargs...)
2020
end
21-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
21+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
2222
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
2323
return LQViaTransposedQR(qr_alg)
2424
end
25-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
25+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
2626
return CUSOLVER_QRIteration(; kwargs...)
2727
end
28-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
28+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
2929
return CUSOLVER_Simple(; kwargs...)
3030
end
31-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
31+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
3232
return CUSOLVER_DivideAndConquer(; kwargs...)
3333
end
3434

35+
# include for block sector support
36+
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
37+
return CUSOLVER_HouseholderQR(; kwargs...)
38+
end
39+
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
40+
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
41+
return LQViaTransposedQR(qr_alg)
42+
end
43+
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
44+
return CUSOLVER_Jacobi(; kwargs...)
45+
end
46+
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
47+
return CUSOLVER_Simple(; kwargs...)
48+
end
49+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
50+
return CUSOLVER_DivideAndConquer(; kwargs...)
51+
end
3552

3653
_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
3754
YACUSOLVER.Xgeev!(A, D, V)

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ for (bname, fname, elty, relty) in
3030
)
3131
chkstride1(A, U, Vᴴ, S)
3232
m, n = size(A)
33-
(m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n"))
33+
(m < n) && throw(ArgumentError(lazy"CUSOLVER's gesvd requires m ($m) ≥ n ($n)"))
3434
minmn = min(m, n)
3535
if length(U) == 0
3636
jobu = 'N'
@@ -191,14 +191,17 @@ for (bname, fname, elty, relty) in
191191
(:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64),
192192
)
193193
@eval begin
194+
#! format: off
194195
function gesvdj!(
195196
A::StridedCuMatrix{$elty},
196197
S::StridedCuVector{$relty} = similar(A, $relty, min(size(A)...)),
197198
U::StridedCuMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
198199
Vᴴ::StridedCuMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2));
199200
tol::$relty = eps($relty),
200-
max_sweeps::Int = 100
201+
max_sweeps::Int = 100,
202+
kwargs...
201203
)
204+
#! format: on
202205
chkstride1(A, U, Vᴴ, S)
203206
m, n = size(A)
204207
minmn = min(m, n)

src/implementations/eig.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm)
137137
D, V = DV
138138
if alg isa GPU_Simple
139139
isempty(alg.kwargs) ||
140-
throw(ArgumentError("GPU_Simple (geev) does not accept any keyword arguments"))
140+
@warn "GPU_Simple (geev) does not accept any keyword arguments"
141141
_gpu_geev!(A, D.diag, V)
142142
end
143143
# TODO: make this controllable using a `gaugefix` keyword argument
@@ -150,7 +150,7 @@ function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm)
150150
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
151151
if alg isa GPU_Simple
152152
isempty(alg.kwargs) ||
153-
throw(ArgumentError("LAPACK_Simple (geev) does not accept any keyword arguments"))
153+
@warn "GPU_Simple (geev) does not accept any keyword arguments"
154154
_gpu_geev!(A, D, V)
155155
end
156156
return D

src/implementations/eigh.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA
5050
@check_scalar(V, A)
5151
return nothing
5252
end
53+
5354
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::DiagonalAlgorithm)
5455
check_hermitian(A, alg)
5556
@assert isdiag(A)

src/implementations/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ end
418418
_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x)))
419419
_largest(x, y) = abs(x) < abs(y) ? y : x
420420

421-
function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
421+
function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
422422
check_input(svd_vals!, A, S, alg)
423423
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
424424
if alg isa GPU_QRIteration

test/amd/orthnull.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64)
2525
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
2626
@test V * C A
2727
@test isisometric(V)
28-
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
28+
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
2929
@test isisometric(N)
3030
hV = collect(V)
3131
hN = collect(N)

test/amd/polar.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ using AMDGPU
1313
k = min(m, n)
1414
svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
1515
@testset "algorithm $svd_alg" for svd_alg in svd_algs
16-
n < m && svd_alg isa ROCSOLVER_QRIteration && continue
1716
A = ROCArray(randn(rng, T, m, n))
1817
alg = PolarViaSVD(svd_alg)
1918
W, P = left_polar(A; alg)
@@ -52,7 +51,6 @@ end
5251
k = min(m, n)
5352
svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
5453
@testset "algorithm $svd_alg" for svd_alg in svd_algs
55-
n > m && svd_alg isa ROCSOLVER_QRIteration && continue
5654
A = ROCArray(randn(rng, T, m, n))
5755
alg = PolarViaSVD(svd_alg)
5856
P, Wᴴ = right_polar(A; alg)

test/cuda/polar.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ using CUDA
1313
k = min(m, n)
1414
svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
1515
@testset "algorithm $svd_alg" for svd_alg in svd_algs
16-
n < m && svd_alg isa CUSOLVER_QRIteration && continue
1716
A = CuArray(randn(rng, T, m, n))
1817
alg = PolarViaSVD(svd_alg)
1918
W, P = left_polar(A; alg)
@@ -52,7 +51,6 @@ end
5251
k = min(m, n)
5352
svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
5453
@testset "algorithm $svd_alg" for svd_alg in svd_algs
55-
n > m && svd_alg isa CUSOLVER_QRIteration && continue
5654
A = CuArray(randn(rng, T, m, n))
5755
alg = PolarViaSVD(svd_alg)
5856
P, Wᴴ = right_polar(A; alg)

0 commit comments

Comments
 (0)