diff --git a/Project.toml b/Project.toml index 53bac9a5..507d81df 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.2.0" +version = "0.2.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -36,4 +36,5 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", + "ChainRulesTestUtils", "StableRNGs", "Zygote"] diff --git a/src/algorithms.jl b/src/algorithms.jl index 4a6bfe58..f559b42e 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -61,6 +61,7 @@ end MatrixAlgebraKit.select_algorithm(f, A, (; kwargs...)) Decide on an algorithm to use for implementing the function `f` on inputs of type `A`. +This can be obtained both for values `A` or types `A`. If `alg` is an `AbstractAlgorithm` instance, it will be returned as-is. @@ -73,44 +74,46 @@ automatically with [`MatrixAlgebraKit.default_algorithm`](@ref) and the keyword arguments in `kwargs` will be passed to the algorithm constructor. Finally, the same behavior is obtained when the keyword arguments are passed as the third positional argument in the form of a `NamedTuple`. -""" -function select_algorithm end +""" select_algorithm function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg} - return _select_algorithm(f, A, alg; kwargs...) + return select_algorithm(f, typeof(A), alg; kwargs...) end +function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg} + if isnothing(alg) + return default_algorithm(f, A; kwargs...) + elseif alg isa Symbol + return Algorithm{alg}(; kwargs...) + elseif alg isa Type + return alg(; kwargs...) + elseif alg isa NamedTuple + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified.")) + return default_algorithm(f, A; alg...) + elseif alg isa AbstractAlgorithm + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified.")) + return alg + end -function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F} - return default_algorithm(f, A; kwargs...) -end -function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F} - return Algorithm{alg}(; kwargs...) -end -function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg} - return Alg(; kwargs...) -end -function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F} - isempty(kwargs) || - throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified.")) - return default_algorithm(f, A; alg...) -end -function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F} - isempty(kwargs) || - throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified.")) - return alg -end -function _select_algorithm(f::F, A, alg; kwargs...) where {F} - return throw(ArgumentError("Unknown alg $alg")) + throw(ArgumentError("Unknown alg $alg")) end + @doc """ MatrixAlgebraKit.default_algorithm(f, A; kwargs...) + MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA} Select the default algorithm for a given factorization function `f` and input `A`. In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified explicitly. -""" -function default_algorithm end +New types should prefer to register their default algorithms in the type domain. +""" default_algorithm +default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...) +# avoid infinite recursion: +function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T} + throw(MethodError(default_algorithm, (f, T))) +end @doc """ copy_input(f, A) @@ -118,8 +121,7 @@ function default_algorithm end Preprocess the input `A` for a given function, such that it may be handled correctly later. This may include a copy whenever the implementation would destroy the original matrix, or a change of element type to something that is supported. -""" -function copy_input end +""" copy_input @doc """ initialize_output(f, A, alg) @@ -127,8 +129,7 @@ function copy_input end Whenever possible, allocate the destination for applying a given algorithm in-place. If this is not possible, for example when the output size is not known a priori or immutable, this function may return `nothing`. -""" -function initialize_output end +""" initialize_output # Utility macros # -------------- @@ -176,25 +177,35 @@ macro functiondef(f) f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`")) f! = Symbol(f, :!) - return esc(quote - # out of place to inplace - $f(A; kwargs...) = $f!(copy_input($f, A); kwargs...) - $f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg) - - # fill in arguments - function $f!(A; alg=nothing, kwargs...) - return $f!(A, select_algorithm($f!, A, alg; kwargs...)) - end - function $f!(A, out; alg=nothing, kwargs...) - return $f!(A, out, select_algorithm($f!, A, alg; kwargs...)) - end - function $f!(A, alg::AbstractAlgorithm) - return $f!(A, initialize_output($f!, A, alg), alg) - end - - # copy documentation to both functions - Core.@__doc__ $f, $f! - end) + ex = quote + # out of place to inplace + $f(A; kwargs...) = $f!(copy_input($f, A); kwargs...) + $f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg) + + # fill in arguments + function $f!(A; alg=nothing, kwargs...) + return $f!(A, select_algorithm($f!, A, alg; kwargs...)) + end + function $f!(A, out; alg=nothing, kwargs...) + return $f!(A, out, select_algorithm($f!, A, alg; kwargs...)) + end + function $f!(A, alg::AbstractAlgorithm) + return $f!(A, initialize_output($f!, A, alg), alg) + end + + # define fallbacks for algorithm selection + @inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg; + kwargs...) where {Alg,A} + return select_algorithm($f!, A, alg; kwargs...) + end + @inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_algorithm($f!, A; kwargs...) + end + + # copy documentation to both functions + Core.@__doc__ $f, $f! + end + return esc(ex) end """ diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 9071a657..77aa0672 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -87,27 +87,20 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # Algorithm selection # ------------------- -for f in (:eig_full, :eig_vals) - f! = Symbol(f, :!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_eig_algorithm(A; kwargs...) - end - end +default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...) +default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,))) +function default_eig_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} + return LAPACK_Expert(; kwargs...) end -function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...) - return select_algorithm(eig_trunc!, A, alg; kwargs...) +for f in (:eig_full!, :eig_vals!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eig_algorithm(A; kwargs...) + end end -function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...) + +function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing, + kwargs...) where {A<:YALAPACK.BlasMat} alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) end - -# Default to LAPACK -function default_eig_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...) - return LAPACK_Expert(; kwargs...) -end diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index b092795c..3ed38789 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -86,27 +86,22 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc) # Algorithm selection # ------------------- -for f in (:eigh_full, :eigh_vals) - f! = Symbol(f, :!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_eigh_algorithm(A; kwargs...) - end - end +default_eigh_algorithm(A; kwargs...) = default_eigh_algorithm(typeof(A); kwargs...) +function default_eigh_algorithm(T::Type; kwargs...) + throw(MethodError(default_eigh_algorithm, (T,))) +end +function default_eigh_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} + return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) end -function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...) - return select_algorithm(eigh_trunc!, A, alg; kwargs...) +for f in (:eigh_full!, :eigh_vals!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eigh_algorithm(A; kwargs...) + end end -function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...) + +function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing, + kwargs...) where {A<:YALAPACK.BlasMat} alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...) return TruncatedAlgorithm(alg_eigh, select_truncation(trunc)) end - -# Default to LAPACK -function default_eigh_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} - return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) -end diff --git a/src/interface/lq.jl b/src/interface/lq.jl index e98223f1..9de85bae 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -68,19 +68,18 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact). # Algorithm selection # ------------------- -for f in (:lq_full, :lq_compact, :lq_null) - f! = Symbol(f, :!) +default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...) +function default_lq_algorithm(T::Type; kwargs...) + throw(MethodError(default_lq_algorithm, (T,))) +end +function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} + return LAPACK_HouseholderLQ(; kwargs...) +end + +for f in (:lq_full!, :lq_compact!, :lq_null!) @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) + function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} return default_lq_algorithm(A; kwargs...) end end end - -# Default to LAPACK -function default_lq_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...) - return LAPACK_HouseholderLQ(; kwargs...) -end diff --git a/src/interface/polar.jl b/src/interface/polar.jl index b209a327..111a6e3a 100644 --- a/src/interface/polar.jl +++ b/src/interface/polar.jl @@ -60,19 +60,16 @@ end # Algorithm selection # ------------------- -for f in (:left_polar, :right_polar) - f! = Symbol(f, :!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_polar_algorithm(A; kwargs...) - end - end +default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...) +function default_polar_algorithm(T::Type; kwargs...) + throw(MethodError(default_polar_algorithm, (T,))) +end +function default_polar_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} + return PolarViaSVD(default_algorithm(svd_compact!, T; kwargs...)) end -# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}` -function default_polar_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...) - return PolarViaSVD(default_svd_algorithm(A; kwargs...)) +for f in (:left_polar!, :right_polar!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_polar_algorithm(A; kwargs...) + end end diff --git a/src/interface/qr.jl b/src/interface/qr.jl index cbded32d..a1c12b7a 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -68,19 +68,19 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact). # Algorithm selection # ------------------- -for f in (:qr_full, :qr_compact, :qr_null) - f! = Symbol(f, :!) +default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...) +function default_qr_algorithm(T::Type; kwargs...) + throw(MethodError(default_qr_algorithm, (T,))) +end +function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} + return LAPACK_HouseholderQR(; kwargs...) +end + +for f in (:qr_full!, :qr_compact!, :qr_null!) @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) + function default_algorithm(::typeof($f), ::Type{A}; + kwargs...) where {A<:YALAPACK.BlasMat} return default_qr_algorithm(A; kwargs...) end end end - -# Default to LAPACK -function default_qr_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...) - return LAPACK_HouseholderQR(; kwargs...) -end diff --git a/src/interface/schur.jl b/src/interface/schur.jl index c49e6a2b..19f6dc00 100644 --- a/src/interface/schur.jl +++ b/src/interface/schur.jl @@ -51,14 +51,8 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # Algorithm selection # ------------------- -for f in (:schur_full, :schur_vals) - f! = Symbol(f, :!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_eig_algorithm(A; kwargs...) - end +for f in (:schur_full!, :schur_vals!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eig_algorithm(A; kwargs...) end end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 1c5d7e3a..e6f9021a 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -90,27 +90,22 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact) an # Algorithm selection # ------------------- -for f in (:svd_full, :svd_compact, :svd_vals) - f! = Symbol(f, :!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_svd_algorithm(A; kwargs...) - end - end +default_svd_algorithm(A; kwargs...) = default_svd_algorithm(typeof(A); kwargs...) +function default_svd_algorithm(T::Type; kwargs...) + throw(MethodError(default_svd_algorithm, (T,))) +end +function default_svd_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} + return LAPACK_DivideAndConquer(; kwargs...) end -function select_algorithm(::typeof(svd_trunc), A, alg; kwargs...) - return select_algorithm(svd_trunc!, A, alg; kwargs...) +for f in (:svd_full!, :svd_compact!, :svd_vals!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_svd_algorithm(A; kwargs...) + end end -function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...) + +function select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; trunc=nothing, + kwargs...) where {A<:YALAPACK.BlasMat} alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) end - -# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}` -function default_svd_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...) - return LAPACK_DivideAndConquer(; kwargs...) -end diff --git a/src/yalapack.jl b/src/yalapack.jl index 2bc08946..e8f4b685 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -16,6 +16,9 @@ using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK, using LinearAlgebra.BLAS: @blasfunc, libblastrampoline using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror +# type alias for matrices that are definitely supported by YALAPACK +const BlasMat{T<:BlasFloat} = StridedMatrix{T} + # LU factorisation for (getrf, getrs, elty) in ((:dgetrf_, :dgetrs_, :Float64), (:sgetrf_, :sgetrs_, :Float32), diff --git a/test/orthnull.jl b/test/orthnull.jl index 7f497eaf..b0004739 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -39,8 +39,8 @@ end function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC) return check_input(right_orth!, parent(A), parent.(VC)) end -function MatrixAlgebraKit.default_svd_algorithm(A::LinearMap) - return default_svd_algorithm(parent(A)) +function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A} + return default_svd_algorithm(A; kwargs...) end function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::LinearMap, alg::LAPACK_SVDAlgorithm)