Skip to content

Commit f2822ca

Browse files
committed
Updates for TensorKit compatibility
1 parent 8f449d4 commit f2822ca

File tree

11 files changed

+43
-18
lines changed

11 files changed

+43
-18
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ Zygote = "0.7"
3838
julia = "1.10"
3939

4040
[extras]
41+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
4142
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4243
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
44+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4345
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4446
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4547
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T
2828
return ROCSOLVER_DivideAndConquer(; kwargs...)
2929
end
3030

31+
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = ishermitian(A)
32+
3133
_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A)
3234
_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ)
3335
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedROCMatrix, τ::StridedROCVector, C::StridedROCVecOrMat) =

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,49 @@ using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
11-
using CUDA
11+
using CUDA, CUDA.CUBLAS
1212
using CUDA: i32
1313
using LinearAlgebra
1414
using LinearAlgebra: BlasFloat
1515

16+
using CUDA: i32
17+
1618
include("yacusolver.jl")
1719

18-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
20+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
1921
return CUSOLVER_HouseholderQR(; kwargs...)
2022
end
21-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
23+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
2224
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
2325
return LQViaTransposedQR(qr_alg)
2426
end
25-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
27+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
2628
return CUSOLVER_QRIteration(; kwargs...)
2729
end
28-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
30+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
2931
return CUSOLVER_Simple(; kwargs...)
3032
end
31-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
33+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
3234
return CUSOLVER_DivideAndConquer(; kwargs...)
3335
end
3436

37+
# include for block sector support
38+
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
39+
return CUSOLVER_HouseholderQR(; kwargs...)
40+
end
41+
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
42+
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
43+
return LQViaTransposedQR(qr_alg)
44+
end
45+
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
46+
return CUSOLVER_Jacobi(; kwargs...)
47+
end
48+
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
49+
return CUSOLVER_Simple(; kwargs...)
50+
end
51+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
52+
return CUSOLVER_DivideAndConquer(; kwargs...)
53+
end
3554

3655
_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
3756
YACUSOLVER.Xgeev!(A, D, V)

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ for (bname, fname, elty, relty) in
3030
)
3131
chkstride1(A, U, Vᴴ, S)
3232
m, n = size(A)
33-
(m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n"))
33+
(m < n) && throw(ArgumentError(lazy"CUSOLVER's gesvd requires m ($m) ≥ n ($n)"))
3434
minmn = min(m, n)
3535
if length(U) == 0
3636
jobu = 'N'
@@ -191,14 +191,17 @@ for (bname, fname, elty, relty) in
191191
(:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64),
192192
)
193193
@eval begin
194+
#! format: off
194195
function gesvdj!(
195196
A::StridedCuMatrix{$elty},
196197
S::StridedCuVector{$relty} = similar(A, $relty, min(size(A)...)),
197198
U::StridedCuMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
198199
Vᴴ::StridedCuMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2));
199200
tol::$relty = eps($relty),
200-
max_sweeps::Int = 100
201+
max_sweeps::Int = 100,
202+
kwargs...
201203
)
204+
#! format: on
202205
chkstride1(A, U, Vᴴ, S)
203206
m, n = size(A)
204207
minmn = min(m, n)

src/implementations/eig.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,9 @@ function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm)
149149
check_input(eig_vals!, A, D, alg)
150150
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
151151
if alg isa GPU_Simple
152-
isempty(alg.kwargs) ||
153-
throw(ArgumentError("LAPACK_Simple (geev) does not accept any keyword arguments"))
152+
# TODO filter out nothing kwargs
153+
#isempty(alg.kwargs) ||
154+
# throw(ArgumentError("GPU_Simple (geev) does not accept any keyword arguments"))
154155
_gpu_geev!(A, D, V)
155156
end
156157
return D

src/implementations/eigh.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA
5050
@check_scalar(V, A)
5151
return nothing
5252
end
53+
5354
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::DiagonalAlgorithm)
5455
check_hermitian(A, alg)
5556
@assert isdiag(A)

src/implementations/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ end
418418
_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x)))
419419
_largest(x, y) = abs(x) < abs(y) ? y : x
420420

421-
function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
421+
function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
422422
check_input(svd_vals!, A, S, alg)
423423
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
424424
if alg isa GPU_QRIteration

test/amd/orthnull.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64)
2525
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
2626
@test V * C A
2727
@test isisometric(V)
28-
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
28+
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
2929
@test isisometric(N)
3030
hV = collect(V)
3131
hN = collect(N)

test/amd/polar.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ using AMDGPU
1313
k = min(m, n)
1414
svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
1515
@testset "algorithm $svd_alg" for svd_alg in svd_algs
16-
n < m && svd_alg isa ROCSOLVER_QRIteration && continue
1716
A = ROCArray(randn(rng, T, m, n))
1817
alg = PolarViaSVD(svd_alg)
1918
W, P = left_polar(A; alg)
@@ -23,6 +22,7 @@ using AMDGPU
2322
@test isisometric(W)
2423
# work around extremely strict Julia criteria for Hermiticity
2524
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
25+
@test isposdef(P)
2626

2727
Ac = similar(A)
2828
W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, P), alg)
@@ -52,7 +52,6 @@ end
5252
k = min(m, n)
5353
svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
5454
@testset "algorithm $svd_alg" for svd_alg in svd_algs
55-
n > m && svd_alg isa ROCSOLVER_QRIteration && continue
5655
A = ROCArray(randn(rng, T, m, n))
5756
alg = PolarViaSVD(svd_alg)
5857
P, Wᴴ = right_polar(A; alg)

test/cuda/polar.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ using CUDA
1313
k = min(m, n)
1414
svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
1515
@testset "algorithm $svd_alg" for svd_alg in svd_algs
16-
n < m && svd_alg isa CUSOLVER_QRIteration && continue
1716
A = CuArray(randn(rng, T, m, n))
1817
alg = PolarViaSVD(svd_alg)
1918
W, P = left_polar(A; alg)
@@ -52,7 +51,6 @@ end
5251
k = min(m, n)
5352
svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
5453
@testset "algorithm $svd_alg" for svd_alg in svd_algs
55-
n > m && svd_alg isa CUSOLVER_QRIteration && continue
5654
A = CuArray(randn(rng, T, m, n))
5755
alg = PolarViaSVD(svd_alg)
5856
P, Wᴴ = right_polar(A; alg)

0 commit comments

Comments
 (0)