Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MatrixAlgebraKit"
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
authors = ["Jutho <[email protected]> and contributors"]
version = "0.2.1"
version = "0.2.2"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
11 changes: 3 additions & 8 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@ Finally, the same behavior is obtained when the keyword arguments are
passed as the third positional argument in the form of a `NamedTuple`.
""" select_algorithm

function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
return select_algorithm(f, typeof(A), alg; kwargs...)
end
function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg}
function select_algorithm(f::F, A::T, alg::Alg=nothing; kwargs...) where {F,T,Alg}
if isnothing(alg)
return default_algorithm(f, A; kwargs...)
elseif alg isa Symbol
Expand All @@ -99,7 +96,6 @@ function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F
throw(ArgumentError("Unknown alg $alg"))
end


@doc """
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA}
Expand Down Expand Up @@ -194,11 +190,10 @@ macro functiondef(f)
end

# define fallbacks for algorithm selection
@inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg;
kwargs...) where {Alg,A}
@inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg}
return select_algorithm($f!, A, alg; kwargs...)
end
@inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
@inline function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end

Expand Down
3 changes: 1 addition & 2 deletions src/interface/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ for f in (:eig_full!, :eig_vals!)
end
end

function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
end
3 changes: 1 addition & 2 deletions src/interface/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ for f in (:eigh_full!, :eigh_vals!)
end
end

function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
end
6 changes: 2 additions & 4 deletions src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
end

for f in (:lq_full!, :lq_compact!, :lq_null!)
@eval begin
function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_lq_algorithm(A; kwargs...)
end
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_lq_algorithm(A; kwargs...)
end
end
5 changes: 1 addition & 4 deletions src/interface/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ end
# Algorithm selection
# -------------------
default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...)
function default_polar_algorithm(T::Type; kwargs...)
throw(MethodError(default_polar_algorithm, (T,)))
end
function default_polar_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
function default_polar_algorithm(::Type{T}; kwargs...) where {T}
return PolarViaSVD(default_algorithm(svd_compact!, T; kwargs...))
end

Expand Down
7 changes: 2 additions & 5 deletions src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
end

for f in (:qr_full!, :qr_compact!, :qr_null!)
@eval begin
function default_algorithm(::typeof($f), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return default_qr_algorithm(A; kwargs...)
end
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_qr_algorithm(A; kwargs...)
end
end
3 changes: 1 addition & 2 deletions src/interface/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!)
end
end

function select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...)
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
end
Loading