Skip to content

Commit 871cbe6

Browse files
committed
Fix default_algorithm for CUDA matrices
1 parent 6526e58 commit 871cbe6

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: LQViaTransposedQR
8+
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
89
using CUDA
910
using LinearAlgebra
1011
using LinearAlgebra: BlasFloat
@@ -13,15 +14,15 @@ include("yacusolver.jl")
1314
include("implementations/qr.jl")
1415
include("implementations/svd.jl")
1516

16-
function MatrixAlgebraKit.default_qr_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...)
17+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
1718
return CUSOLVER_HouseholderQR(; kwargs...)
1819
end
19-
function MatrixAlgebraKit.default_lq_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...)
20+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
2021
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
2122
return LQViaTransposedQR(qr_alg)
2223
end
23-
function MatrixAlgebraKit.default_svd_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...)
24+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
2425
return CUSOLVER_QRIteration(; kwargs...)
2526
end
2627

27-
end
28+
end

0 commit comments

Comments
 (0)