Skip to content
Merged
5 changes: 5 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ export eigh_full, eigh_vals, eigh_trunc
export eigh_full!, eigh_vals!, eigh_trunc!
export eig_full, eig_vals, eig_trunc
export eig_full!, eig_vals!, eig_trunc!
export gen_eig_full, gen_eig_vals
export gen_eig_full!, gen_eig_vals!
export schur_full, schur_vals
export schur_full!, schur_vals!
export left_polar, right_polar
Expand All @@ -45,6 +47,7 @@ include("common/safemethods.jl")
include("common/view.jl")
include("common/regularinv.jl")
include("common/matrixproperties.jl")
include("common/gauge.jl")

include("yalapack.jl")
include("algorithms.jl")
Expand All @@ -54,6 +57,7 @@ include("interface/lq.jl")
include("interface/svd.jl")
include("interface/eig.jl")
include("interface/eigh.jl")
include("interface/gen_eig.jl")
include("interface/schur.jl")
include("interface/polar.jl")
include("interface/orthnull.jl")
Expand All @@ -64,6 +68,7 @@ include("implementations/lq.jl")
include("implementations/svd.jl")
include("implementations/eig.jl")
include("implementations/eigh.jl")
include("implementations/gen_eig.jl")
include("implementations/schur.jl")
include("implementations/polar.jl")
include("implementations/orthnull.jl")
Expand Down
133 changes: 110 additions & 23 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,14 @@ explicitly.
New types should prefer to register their default algorithms in the type domain.
""" default_algorithm
default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
default_algorithm(f::F, A, B; kwargs...) where {F} = default_algorithm(f, typeof(A), typeof(B); kwargs...)
# avoid infinite recursion:
function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
throw(MethodError(default_algorithm, (f, T)))
end
function default_algorithm(f::F, ::Type{TA}, ::Type{TB}; kwargs...) where {F,TA,TB}
throw(MethodError(default_algorithm, (f, TA, TB)))
end

@doc """
copy_input(f, A)
Expand Down Expand Up @@ -153,28 +157,8 @@ macro algdef(name)
end)
end

"""
@functiondef f

Convenience macro to define the boilerplate code that dispatches between several versions of `f` and `f!`.
By default, this enables the following signatures to be defined in terms of
the final `f!(A, out, alg::Algorithm)`.

```julia
f(A; kwargs...)
f(A, alg::Algorithm)
f!(A, [out]; kwargs...)
f!(A, alg::Algorithm)
```

See also [`copy_input`](@ref), [`select_algorithm`](@ref) and [`initialize_output`](@ref).
"""
macro functiondef(f)
f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`"))
f! = Symbol(f, :!)

ex = quote
# out of place to inplace
function _arg_expr(::Val{1}, f, f!)
return quote # out of place to inplace
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)

Expand Down Expand Up @@ -215,7 +199,110 @@ macro functiondef(f)
# copy documentation to both functions
Core.@__doc__ $f, $f!
end
return esc(ex)
end

function _arg_expr(::Val{2}, f, f!)
return quote
# out of place to inplace
$f(A, B; kwargs...) = $f!(copy_input($f, A, B)...; kwargs...)
$f(A, B, alg::AbstractAlgorithm) = $f!(copy_input($f, A, B)..., alg)

# fill in arguments
function $f!(A, B; alg=nothing, kwargs...)
return $f!(A, B, select_algorithm($f!, (A, B), alg; kwargs...))
end
function $f!(A, B, out; alg=nothing, kwargs...)
return $f!(A, B, out, select_algorithm($f!, (A, B), alg; kwargs...))
end
function $f!(A, B, alg::AbstractAlgorithm)
return $f!(A, B, initialize_output($f!, A, B, alg), alg)
end

# define fallbacks for algorithm selection
@inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg}
return select_algorithm($f!, A, alg; kwargs...)
end
# define default algorithm fallbacks for out-of-place functions
# in terms of the corresponding in-place function
@inline function default_algorithm(::typeof($f), A, B; kwargs...)
return default_algorithm($f!, A, B; kwargs...)
end
# define default algorithm fallbacks for out-of-place functions
# in terms of the corresponding in-place function for types,
# in principle this is covered by the definition above but
# it is necessary to avoid ambiguity errors with the generic definitions:
# ```julia
# default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
# function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
# throw(MethodError(default_algorithm, (f, T)))
# end
# ```
@inline function default_algorithm(::typeof($f), ::Type{A}, ::Type{B}; kwargs...) where {A, B}
return default_algorithm($f!, A, B; kwargs...)
end

# copy documentation to both functions
Core.@__doc__ $f, $f!
end
end

"""
@functiondef [n_args=1] f

Convenience macro to define the boilerplate code that dispatches between several versions of `f` and `f!`.
By default, `f` accepts a single argument `A`. This enables the following signatures to be defined in terms of
the final `f!(A, out, alg::Algorithm)`.

```julia
f(A; kwargs...)
f(A, alg::Algorithm)
f!(A, [out]; kwargs...)
f!(A, alg::Algorithm)
```

The number of inputs can be set with the `n_args` keyword
argument, so that

```julia
@functiondef n_args=2 f
```

would create

```julia
f(A, B; kwargs...)
f(A, B, alg::Algorithm)
f!(A, B, [out]; kwargs...)
f!(A, B, alg::Algorithm)
```

See also [`copy_input`](@ref), [`select_algorithm`](@ref) and [`initialize_output`](@ref).
"""
macro functiondef(args...)
kwargs = map(args[1:end-1]) do kwarg
if kwarg isa Symbol
:($kwarg = $kwarg)
elseif Meta.isexpr(kwarg, :(=))
kwarg
else
throw(ArgumentError("Invalid keyword argument '$kwarg'"))
end
end
isempty(kwargs) || length(kwargs) == 1 || throw(ArgumentError("Only one keyword argument to `@functiondef` is supported"))
f_n_args = 1 # default
if length(kwargs) == 1
kwarg = only(kwargs) # only one kwarg is currently supported, TODO modify if we support more
key::Symbol, val = kwarg.args
key === :n_args || throw(ArgumentError("Unsupported keyword argument $key to `@functiondef`"))
(isa(val, Integer) && val > 0) || throw(ArgumentError("`n_args` keyword argument to `@functiondef` should be an integer > 0"))
f_n_args = val
end

f = args[end]
f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`"))
f! = Symbol(f, :!)

return esc(_arg_expr(Val(f_n_args), f, f!))
end

"""
Expand Down
8 changes: 8 additions & 0 deletions src/common/gauge.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
function gaugefix!(V::AbstractMatrix)
for j in axes(V, 2)
v = view(V, :, j)
s = conj(sign(argmax(abs, v)))
@inbounds v .*= s
end
return V
end
6 changes: 1 addition & 5 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
YALAPACK.geevx!(A, D.diag, V; alg.kwargs...)
end
# TODO: make this controllable using a `gaugefix` keyword argument
for j in 1:size(V, 2)
v = view(V, :, j)
s = conj(sign(argmax(abs, v)))
v .*= s
end
V = gaugefix!(V)
return D, V
end

Expand Down
85 changes: 85 additions & 0 deletions src/implementations/gen_eig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Inputs
# ------
function copy_input(::typeof(gen_eig_full), A::AbstractMatrix, B::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A), copy!(similar(B, float(eltype(B))), B)
end
function copy_input(::typeof(gen_eig_vals), A::AbstractMatrix, B::AbstractMatrix)
return copy_input(gen_eig_full, A, B)
end

function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV)
ma, na = size(A)
mb, nb = size(B)
ma == na || throw(DimensionMismatch("square input matrix A expected"))
mb == nb || throw(DimensionMismatch("square input matrix B expected"))
ma == mb || throw(DimensionMismatch("first dimension of input matrices expected to match"))
na == nb || throw(DimensionMismatch("second dimension of input matrices expected to match"))
W, V = WV
@assert W isa Diagonal && V isa AbstractMatrix
@check_size(W, (ma, ma))
@check_scalar(W, A, complex)
@check_scalar(W, B, complex)
@check_size(V, (ma, ma))
@check_scalar(V, A, complex)
@check_scalar(V, B, complex)
return nothing
end
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W)
ma, na = size(A)
mb, nb = size(B)
ma == na || throw(DimensionMismatch("square input matrix A expected"))
mb == nb || throw(DimensionMismatch("square input matrix B expected"))
ma == mb || throw(DimensionMismatch("dimension of input matrices expected to match"))
@assert W isa AbstractVector
@check_size(W, (na,))
@check_scalar(W, A, complex)
@check_scalar(W, B, complex)
return nothing
end

# Outputs
# -------
function initialize_output(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, ::LAPACK_EigAlgorithm)
n = size(A, 1) # square check will happen later
Tc = complex(eltype(A))
W = Diagonal(similar(A, Tc, n))
V = similar(A, Tc, (n, n))
return (W, V)
end
function initialize_output(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, ::LAPACK_EigAlgorithm)
n = size(A, 1) # square check will happen later
Tc = complex(eltype(A))
D = similar(A, Tc, n)
return D
end

# Implementation
# --------------
# actual implementation
function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_EigAlgorithm)
check_input(gen_eig_full!, A, B, WV)
W, V = WV
if alg isa LAPACK_Simple
isempty(alg.kwargs) ||
throw(ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments"))
YALAPACK.ggev!(A, B, W.diag, V, similar(W.diag, eltype(A)))
else # alg isa LAPACK_Expert
throw(ArgumentError("LAPACK_Expert is not supported for ggev"))
end
# TODO: make this controllable using a `gaugefix` keyword argument
V = gaugefix!(V)
return W, V
end

function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigAlgorithm)
check_input(gen_eig_vals!, A, B, W)
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
if alg isa LAPACK_Simple
isempty(alg.kwargs) ||
throw(ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments"))
YALAPACK.ggev!(A, B, W, V, similar(W, eltype(A)))
else # alg isa LAPACK_Expert
throw(ArgumentError("LAPACK_Expert is not supported for ggev"))
end
return W
end
10 changes: 0 additions & 10 deletions src/interface/eig.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
# Eig API
# -------
# TODO: export? or not export but mark as public ?
function eig!(A::AbstractMatrix, args...; kwargs...)
return eig_full!(A, args...; kwargs...)
end
function eig(A::AbstractMatrix, args...; kwargs...)
return eig_full(A, args...; kwargs...)
end

# Eig functions
# -------------

Expand Down
69 changes: 69 additions & 0 deletions src/interface/gen_eig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Gen Eig functions
# -------------

# TODO: kwargs for sorting eigenvalues?

docs_gen_eig_note = """
Note that [`gen_eig_full`](@ref) and its variants do not assume additional structure on the inputs,
and therefore will always return complex eigenvalues and eigenvectors. For the real
generalized eigenvalue decomposition is not yet supported.
"""

# TODO: do we need "full"?
"""
gen_eig_full(A, B; kwargs...) -> W, V
gen_eig_full(A, B, alg::AbstractAlgorithm) -> W, V
gen_eig_full!(A, B, [WV]; kwargs...) -> W, V
gen_eig_full!(A, B, [WV], alg::AbstractAlgorithm) -> W, V

Compute the full generalized eigenvalue decomposition of the square matrices `A` and `B`,
such that `A * V = B * V * W`, where the invertible matrix `V` contains the generalized eigenvectors
and the diagonal matrix `W` contains the associated generalized eigenvalues.

!!! note
The bang method `gen_eig_full!` optionally accepts the output structure and
possibly destroys the input matrices `A` and `B`.
Always use the return value of the function as it may not always be
possible to use the provided `WV` as output.

!!! note
$(docs_gen_eig_note)

See also [`gen_eig_vals(!)`](@ref eig_vals).
"""
@functiondef n_args=2 gen_eig_full

"""
gen_eig_vals(A, B; kwargs...) -> W
gen_eig_vals(A, B, alg::AbstractAlgorithm) -> W
gen_eig_vals!(A, B, [W]; kwargs...) -> W
gen_eig_vals!(A, B, [W], alg::AbstractAlgorithm) -> W

Compute the list of generalized eigenvalues of `A` and `B`.

!!! note
The bang method `gen_eig_vals!` optionally accepts the output structure and
possibly destroys the input matrices `A` and `B`. Always use the return
value of the function as it may not always be possible to use the
provided `W` as output.

!!! note
$(docs_gen_eig_note)

See also [`gen_eig_full(!)`](@ref gen_eig_full).
"""
@functiondef n_args=2 gen_eig_vals

# Algorithm selection
# -------------------
default_gen_eig_algorithm(A, B; kwargs...) = default_gen_eig_algorithm(typeof(A), typeof(B); kwargs...)
default_gen_eig_algorithm(::Type{TA}, ::Type{TB}; kwargs...) where {TA, TB} = throw(MethodError(default_gen_eig_algorithm, (TA,TB)))
function default_gen_eig_algorithm(::Type{TA}, ::Type{TB}; kwargs...) where {TA<:YALAPACK.BlasMat,TB<:YALAPACK.BlasMat}
return LAPACK_Simple(; kwargs...)
end

for f in (:gen_eig_full!, :gen_eig_vals!)
@eval function default_algorithm(::typeof($f), ::Tuple{A, B}; kwargs...) where {A, B}
return default_gen_eig_algorithm(A, B; kwargs...)
end
end
Loading
Loading