Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiagonalArrays"
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.15"
version = "0.3.16"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand All @@ -11,11 +11,18 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"

[weakdeps]
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"

[extensions]
DiagonalArraysMatrixAlgebraKitExt = "MatrixAlgebraKit"

[compat]
ArrayLayouts = "1.10.4"
DerivableInterfaces = "0.5.5"
FillArrays = "1.13.0"
LinearAlgebra = "1.10.0"
MapBroadcast = "0.1.10"
MatrixAlgebraKit = "0.2"
SparseArraysBase = "0.7.2"
julia = "1.10"
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
module DiagonalArraysMatrixAlgebraKitExt

using DiagonalArrays:
AbstractDiagonalMatrix,
DeltaMatrix,
DiagonalMatrix,
ScaledDeltaMatrix,
δ,
diagview,
dual,
issquare
using LinearAlgebra: LinearAlgebra, isdiag, ishermitian
using MatrixAlgebraKit:
MatrixAlgebraKit,
AbstractAlgorithm,
check_input,
default_qr_algorithm,
eig_full,
eig_full!,
eig_vals,
eig_vals!,
eigh_full,
eigh_full!,
eigh_vals,
eigh_vals!,
left_null,
left_null!,
left_orth,
left_orth!,
left_polar,
left_polar!,
lq_compact,
lq_compact!,
lq_full,
lq_full!,
qr_compact,
qr_compact!,
qr_full,
qr_full!,
right_null,
right_null!,
right_orth,
right_orth!,
right_polar,
right_polar!,
svd_compact,
svd_compact!,
svd_full,
svd_full!,
svd_vals,
svd_vals!

abstract type AbstractDiagonalAlgorithm <: AbstractAlgorithm end

struct DeltaAlgorithm{KWargs<:NamedTuple} <: AbstractDiagonalAlgorithm
kwargs::KWargs
end
DeltaAlgorithm(; kwargs...) = DeltaAlgorithm((; kwargs...))

struct ScaledDeltaAlgorithm{KWargs<:NamedTuple} <: AbstractDiagonalAlgorithm
kwargs::KWargs
end
ScaledDeltaAlgorithm(; kwargs...) = ScaledDeltaAlgorithm((; kwargs...))

for f in [
:eig_full,
:eig_vals,
:eigh_full,
:eigh_vals,
:qr_compact,
:qr_full,
:left_null,
:left_orth,
:left_polar,
:lq_compact,
:lq_full,
:right_null,
:right_orth,
:right_polar,
:svd_compact,
:svd_full,
:svd_vals,
]
@eval begin
MatrixAlgebraKit.copy_input(::typeof($f), a::AbstractDiagonalMatrix) = copy(a)
end
end

for f in [
:default_eig_algorithm,
:default_eigh_algorithm,
:default_lq_algorithm,
:default_qr_algorithm,
:default_polar_algorithm,
:default_svd_algorithm,
]
@eval begin
function MatrixAlgebraKit.$f(::Type{<:DeltaMatrix}; kwargs...)
return DeltaAlgorithm(; kwargs...)
end
function MatrixAlgebraKit.$f(::Type{<:ScaledDeltaMatrix}; kwargs...)
return ScaledDeltaAlgorithm(; kwargs...)
end
end
end

for f in [
:eig_full!,
:eig_vals!,
:eigh_full!,
:eigh_vals!,
:left_null!,
:left_orth!,
:left_polar!,
:lq_compact!,
:lq_full!,
:qr_compact!,
:qr_full!,
:right_null!,
:right_orth!,
:right_polar!,
:svd_compact!,
:svd_full!,
:svd_vals!,
]
for Alg in [:ScaledDeltaAlgorithm, :DeltaAlgorithm]
@eval begin
function MatrixAlgebraKit.initialize_output(::typeof($f), a, alg::$Alg)
return nothing
end
end
end
end

for f in [
:left_null!,
:left_orth!,
:left_polar!,
:lq_compact!,
:lq_full!,
:qr_compact!,
:qr_full!,
:right_null!,
:right_orth!,
:right_polar!,
:svd_compact!,
:svd_full!,
:svd_vals!,
]
@eval begin
function MatrixAlgebraKit.check_input(::typeof($f), a, F, alg::DeltaAlgorithm)
@assert size(a, 1) == size(a, 2)
@assert isdiag(a)
@assert all(isone, diagview(a))
return nothing
end
function MatrixAlgebraKit.check_input(::typeof($f), a, F, alg::ScaledDeltaAlgorithm)
@assert size(a, 1) == size(a, 2)
@assert isdiag(a)
@assert allequal(diagview(a))
return nothing
end
end
end
for f in [:eig_full!, :eig_vals!, :eigh_full!, :eigh_vals!]
@eval begin
function MatrixAlgebraKit.check_input(::typeof($f), a, F, alg::DeltaAlgorithm)
@assert issquare(a)
@assert isdiag(a)
@assert all(isone, diagview(a))
return nothing
end
function MatrixAlgebraKit.check_input(::typeof($f), a, F, alg::ScaledDeltaAlgorithm)
@assert issquare(a)
@assert isdiag(a)
@assert allequal(diagview(a))
return nothing
end
end
end

# eig
for Alg in [:DeltaAlgorithm, :ScaledDeltaAlgorithm]
@eval begin
function MatrixAlgebraKit.eig_full!(a, F, alg::$Alg)
check_input(eig_full!, a, F, alg)
d = complex(a)
v = δ(complex(eltype(a)), axes(a))
return (d, v)
end
function MatrixAlgebraKit.eigh_full!(a, F, alg::$Alg)
check_input(eigh_full!, a, F, alg)
ishermitian(a) || throw(ArgumentError("Matrix must be Hermitian"))
d = real(a)
v = δ(eltype(a), axes(a))
return (d, v)
end
function MatrixAlgebraKit.eig_vals!(a, F, alg::$Alg)
check_input(eig_vals!, a, F, alg)
return complex(diagview(a))
end
function MatrixAlgebraKit.eigh_vals!(a, F, alg::$Alg)
check_input(eigh_vals!, a, F, alg)
return real(diagview(a))
end
end
end

# svd
for f in [:svd_compact!, :svd_full!]
@eval begin
function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm)
check_input($f, a, F, alg)
u = δ(eltype(a), (axes(a, 1), dual(axes(a, 1))))
s = real(a)
v = δ(eltype(a), (dual(axes(a, 2)), axes(a, 2)))
return (u, s, v)
end
function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm)
check_input($f, a, F, alg)
diagvalue = only(unique(diagview(a)))
u = δ(eltype(a), (axes(a, 1), dual(axes(a, 1))))
s = abs(diagvalue) * δ(Bool, axes(a))
# Sign is applied arbitarily to `v`, alternatively
# we could apply it to `u`.
v = sign(diagvalue) * δ(Bool, (dual(axes(a, 2)), axes(a, 2)))
return (u, s, v)
end
end
end
function MatrixAlgebraKit.svd_vals!(a, F, alg::DeltaAlgorithm)
check_input(svd_vals!, a, F, alg)
# Using `real` instead of `abs.` helps to preserve `Ones`.
return real(diagview(a))
end
function MatrixAlgebraKit.svd_vals!(a, F, alg::ScaledDeltaAlgorithm)
check_input(svd_vals!, a, F, alg)
return abs.(diagview(a))
end

# orth
# left_orth is implicitly defined by defining backends like
# qr_compact and left_polar.
for f in [:left_polar!, :qr_compact!, :qr_full!]
@eval begin
function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm)
check_input($f, a, F, alg)
q = δ(eltype(a), (axes(a, 1), dual(axes(a, 1))))
r = copy(a)
return (q, r)
end
function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm)
check_input($f, a, F, alg)
diagvalue = only(unique(diagview(a)))
q = sign(diagvalue) * δ(Bool, (axes(a, 1), dual(axes(a, 1))))
# We're a bit pessimistic about the element type for type stability,
# since in the future we might provide the option to do non-positive QR.
r = eltype(a)(abs(diagvalue)) * δ(Bool, axes(a))
return (q, r)
end
end
end
# right_orth is implicitly defined by defining backends like
# lq_compact and right_polar.
for f in [:right_polar!, :lq_compact!, :lq_full!]
@eval begin
function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm)
check_input($f, a, F, alg)
l = copy(a)
q = δ(eltype(a), (dual(axes(a, 2)), axes(a, 2)))
return (l, q)
end
function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm)
check_input($f, a, F, alg)
diagvalue = only(unique(diagview(a)))
# We're a bit pessimistic about the element type for type stability,
# since in the future we might provide the option to do non-positive LQ.
l = eltype(a)(abs(diagvalue)) * δ(Bool, axes(a))
q = sign(diagvalue) * δ(Bool, (dual(axes(a, 2)), axes(a, 2)))
return (l, q)
end
end
end

# null
for T in [:DeltaMatrix, :ScaledDeltaMatrix]
@eval begin
# TODO: Right now we can't overload `left_null!` on an algorithm,
# make a PR to MatrixAlgebraKit.jl to allow that.
function MatrixAlgebraKit.left_null!(a::$T, F)
check_input(left_null!, a, F, default_qr_algorithm(a))
return error("Not implemented.")
end
# TODO: Right now we can't overload `right_null!` on an algorithm,
# make a PR to MatrixAlgebraKit.jl to allow that.
function MatrixAlgebraKit.right_null!(a::$T, F)
check_input(right_null!, a, F, default_qr_algorithm(a))
return error("Not implemented.")
end
end
end

end
1 change: 1 addition & 0 deletions src/DiagonalArrays.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module DiagonalArrays

include("dual.jl")
include("diaginterface/diaginterface.jl")
include("diaginterface/diagindex.jl")
include("diaginterface/diagindices.jl")
Expand Down
18 changes: 18 additions & 0 deletions src/abstractdiagonalarray/abstractdiagonalarray.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
using SparseArraysBase: AbstractSparseArray

abstract type AbstractDiagonalArray{T,N} <: AbstractSparseArray{T,N} end
const AbstractDiagonalMatrix{T} = AbstractDiagonalArray{T,2}
const AbstractDiagonalVector{T} = AbstractDiagonalArray{T,1}

using LinearAlgebra: LinearAlgebra, ishermitian, isposdef, issymmetric
LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Real}) = issquare(a)
function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Number})
return issquare(a) && isreal(diagview(a))
end
function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix)
return issquare(a) && all(ishermitian, diagview(a))
end
LinearAlgebra.issymmetric(a::AbstractDiagonalMatrix{<:Number}) = issquare(a)
function LinearAlgebra.issymmetric(a::AbstractDiagonalMatrix)
return issquare(a) && all(issymmetric, diagview(a))
end
function LinearAlgebra.isposdef(a::AbstractDiagonalMatrix)
return issquare(a) && all(isposdef, diagview(a))
end
8 changes: 6 additions & 2 deletions src/diagonalarray/diagonalarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,13 @@ function Base.similar(a::DiagonalArray, unstored::Unstored)
return DiagonalArray(undef, unstored)
end

# This definition is helpful for immutable diagonals
# These definitions are helpful for immutable diagonals
# such as FillArrays.
Base.copy(a::DiagonalArray) = DiagonalArray(copy(diagview(a)), axes(a))
for f in [:complex, :copy, :imag, :real]
@eval begin
Base.$f(a::DiagonalArray) = DiagonalArray($f(diagview(a)), axes(a))
end
end

# DiagonalArrays interface.
diagview(a::DiagonalArray) = a.diag
Expand Down
2 changes: 1 addition & 1 deletion src/diagonalarray/diagonalmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using LinearAlgebra: LinearAlgebra

function mul_diagviews(a1, a2)
# TODO: Compare that duals are equal, or define a function to overload.
axes(a1, 2) == axes(a2, 1) || throw(
dual(axes(a1, 2)) == axes(a2, 1) || throw(
DimensionMismatch(
lazy"Incompatible dimensions for multiplication: $(axes(a1)) and $(axes(a2))"
),
Expand Down
3 changes: 3 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# TODO: Define `TensorProducts.dual`.
dual(x) = x
issquare(a::AbstractMatrix) = (axes(a, 1) == dual(axes(a, 2)))
Loading
Loading