Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MatrixAlgebraKit"
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
authors = ["Jutho <[email protected]> and contributors"]
version = "0.2.0"
version = "0.2.1"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down Expand Up @@ -36,4 +36,5 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore",
"ChainRulesTestUtils", "StableRNGs", "Zygote"]
109 changes: 60 additions & 49 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
MatrixAlgebraKit.select_algorithm(f, A, (; kwargs...))

Decide on an algorithm to use for implementing the function `f` on inputs of type `A`.
This can be obtained both for values `A` or types `A`.

If `alg` is an `AbstractAlgorithm` instance, it will be returned as-is.

Expand All @@ -73,62 +74,62 @@
the keyword arguments in `kwargs` will be passed to the algorithm constructor.
Finally, the same behavior is obtained when the keyword arguments are
passed as the third positional argument in the form of a `NamedTuple`.
"""
function select_algorithm end
""" select_algorithm

function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
return _select_algorithm(f, A, alg; kwargs...)
return select_algorithm(f, typeof(A), alg; kwargs...)
end
function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg}
if isnothing(alg)
return default_algorithm(f, A; kwargs...)
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
elseif alg isa Type
return alg(; kwargs...)
elseif alg isa NamedTuple
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
return default_algorithm(f, A; alg...)
elseif alg isa AbstractAlgorithm
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
return alg
end

function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F}
return default_algorithm(f, A; kwargs...)
end
function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F}
return Algorithm{alg}(; kwargs...)
end
function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg}
return Alg(; kwargs...)
end
function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F}
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
return default_algorithm(f, A; alg...)
end
function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F}
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
return alg
end
function _select_algorithm(f::F, A, alg; kwargs...) where {F}
return throw(ArgumentError("Unknown alg $alg"))
throw(ArgumentError("Unknown alg $alg"))

Check warning on line 99 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L99

Added line #L99 was not covered by tests
end


@doc """
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA}

Select the default algorithm for a given factorization function `f` and input `A`.
In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified
explicitly.
"""
function default_algorithm end
New types should prefer to register their default algorithms in the type domain.
""" default_algorithm
default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
# avoid infinite recursion:
function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
throw(MethodError(default_algorithm, (f, T)))

Check warning on line 115 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L114-L115

Added lines #L114 - L115 were not covered by tests
end

@doc """
copy_input(f, A)

Preprocess the input `A` for a given function, such that it may be handled correctly later.
This may include a copy whenever the implementation would destroy the original matrix,
or a change of element type to something that is supported.
"""
function copy_input end
""" copy_input

@doc """
initialize_output(f, A, alg)

Whenever possible, allocate the destination for applying a given algorithm in-place.
If this is not possible, for example when the output size is not known a priori or immutable,
this function may return `nothing`.
"""
function initialize_output end
""" initialize_output

# Utility macros
# --------------
Expand Down Expand Up @@ -176,25 +177,35 @@
f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`"))
f! = Symbol(f, :!)

return esc(quote
# out of place to inplace
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)

# fill in arguments
function $f!(A; alg=nothing, kwargs...)
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, out; alg=nothing, kwargs...)
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, alg::AbstractAlgorithm)
return $f!(A, initialize_output($f!, A, alg), alg)
end

# copy documentation to both functions
Core.@__doc__ $f, $f!
end)
ex = quote
# out of place to inplace
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)

# fill in arguments
function $f!(A; alg=nothing, kwargs...)
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, out; alg=nothing, kwargs...)
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, alg::AbstractAlgorithm)
return $f!(A, initialize_output($f!, A, alg), alg)
end

# define fallbacks for algorithm selection
@inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg;
kwargs...) where {Alg,A}
return select_algorithm($f!, A, alg; kwargs...)
end
@inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_algorithm($f!, A; kwargs...)
end

# copy documentation to both functions
Core.@__doc__ $f, $f!
end
return esc(ex)
end

"""
Expand Down
29 changes: 11 additions & 18 deletions src/interface/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,27 +87,20 @@

# Algorithm selection
# -------------------
for f in (:eig_full, :eig_vals)
f! = Symbol(f, :!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eig_algorithm(A; kwargs...)
end
end
default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...)
default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,)))

Check warning on line 91 in src/interface/eig.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/eig.jl#L90-L91

Added lines #L90 - L91 were not covered by tests
function default_eig_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
return LAPACK_Expert(; kwargs...)
end

function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...)
return select_algorithm(eig_trunc!, A, alg; kwargs...)
for f in (:eig_full!, :eig_vals!)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_eig_algorithm(A; kwargs...)
end
end
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)

function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
end

# Default to LAPACK
function default_eig_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
return LAPACK_Expert(; kwargs...)
end
31 changes: 13 additions & 18 deletions src/interface/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,22 @@

# Algorithm selection
# -------------------
for f in (:eigh_full, :eigh_vals)
f! = Symbol(f, :!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eigh_algorithm(A; kwargs...)
end
end
default_eigh_algorithm(A; kwargs...) = default_eigh_algorithm(typeof(A); kwargs...)
function default_eigh_algorithm(T::Type; kwargs...)
throw(MethodError(default_eigh_algorithm, (T,)))

Check warning on line 91 in src/interface/eigh.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/eigh.jl#L89-L91

Added lines #L89 - L91 were not covered by tests
end
function default_eigh_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...)
end

function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...)
return select_algorithm(eigh_trunc!, A, alg; kwargs...)
for f in (:eigh_full!, :eigh_vals!)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_eigh_algorithm(A; kwargs...)
end
end
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)

function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
end

# Default to LAPACK
function default_eigh_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...)
end
21 changes: 10 additions & 11 deletions src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,18 @@

# Algorithm selection
# -------------------
for f in (:lq_full, :lq_compact, :lq_null)
f! = Symbol(f, :!)
default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...)
function default_lq_algorithm(T::Type; kwargs...)
throw(MethodError(default_lq_algorithm, (T,)))

Check warning on line 73 in src/interface/lq.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/lq.jl#L71-L73

Added lines #L71 - L73 were not covered by tests
end
function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
return LAPACK_HouseholderLQ(; kwargs...)
end

for f in (:lq_full!, :lq_compact!, :lq_null!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_lq_algorithm(A; kwargs...)
end
end
end

# Default to LAPACK
function default_lq_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
return LAPACK_HouseholderLQ(; kwargs...)
end
23 changes: 10 additions & 13 deletions src/interface/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,16 @@

# Algorithm selection
# -------------------
for f in (:left_polar, :right_polar)
f! = Symbol(f, :!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_polar_algorithm(A; kwargs...)
end
end
default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...)
function default_polar_algorithm(T::Type; kwargs...)
throw(MethodError(default_polar_algorithm, (T,)))

Check warning on line 65 in src/interface/polar.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/polar.jl#L63-L65

Added lines #L63 - L65 were not covered by tests
end
function default_polar_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
return PolarViaSVD(default_algorithm(svd_compact!, T; kwargs...))
end

# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}`
function default_polar_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
return PolarViaSVD(default_svd_algorithm(A; kwargs...))
for f in (:left_polar!, :right_polar!)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_polar_algorithm(A; kwargs...)
end
end
22 changes: 11 additions & 11 deletions src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,19 @@

# Algorithm selection
# -------------------
for f in (:qr_full, :qr_compact, :qr_null)
f! = Symbol(f, :!)
default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...)
function default_qr_algorithm(T::Type; kwargs...)
throw(MethodError(default_qr_algorithm, (T,)))

Check warning on line 73 in src/interface/qr.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/qr.jl#L71-L73

Added lines #L71 - L73 were not covered by tests
end
function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
return LAPACK_HouseholderQR(; kwargs...)
end

for f in (:qr_full!, :qr_compact!, :qr_null!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
function default_algorithm(::typeof($f), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return default_qr_algorithm(A; kwargs...)
end
end
end

# Default to LAPACK
function default_qr_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
return LAPACK_HouseholderQR(; kwargs...)
end
12 changes: 3 additions & 9 deletions src/interface/schur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,8 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).

# Algorithm selection
# -------------------
for f in (:schur_full, :schur_vals)
f! = Symbol(f, :!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eig_algorithm(A; kwargs...)
end
for f in (:schur_full!, :schur_vals!)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_eig_algorithm(A; kwargs...)
end
end
Loading
Loading