Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.8"
version = "0.1.9"

[deps]
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"

[compat]
DerivableInterfaces = "0.4.5"
DiagonalArrays = "0.3.5"
FillArrays = "1.13.0"
GPUArraysCore = "0.2.0"
LinearAlgebra = "1.10"
Expand Down
272 changes: 231 additions & 41 deletions src/KroneckerArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ end
function Base.iszero(a::KroneckerArray)
return iszero(a.a) || iszero(a.b)
end
function Base.isreal(a::KroneckerArray)
return isreal(a.a) && isreal(a.b)
end
function Base.inv(a::KroneckerArray)
return inv(a.a) ⊗ inv(a.b)
end
Expand All @@ -270,6 +273,9 @@ end
function Base.:*(a::KroneckerArray, b::Number)
return a.a ⊗ (a.b * b)
end
function Base.:/(a::KroneckerArray, b::Number)
return a * inv(b)
end

function Base.:-(a::KroneckerArray)
return (-a.a) ⊗ a.b
Expand All @@ -291,26 +297,82 @@ for op in (:+, :-)
end
end

using Base.Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted
struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N}
return KroneckerStyle{N,a,b}()
end
function KroneckerStyle(a::AbstractArrayStyle{N}, b::AbstractArrayStyle{N}) where {N}
return KroneckerStyle{N}(a, b)
end
function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}()
end
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B}
return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B))
end
function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N}
return KroneckerStyle{N}(
BroadcastStyle(style1.a, style2.a), BroadcastStyle(style1.b, style2.b)
)
end
function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B}
ax_a = map(ax -> ax.product.a, axes(bc))
ax_b = map(ax -> ax.product.b, axes(bc))
bc_a = Broadcasted(A, nothing, (), ax_a)
bc_b = Broadcasted(B, nothing, (), ax_b)
a = similar(bc_a, elt)
b = similar(bc_b, elt)
return a ⊗ b
end
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:KroneckerStyle})
return throw(
ArgumentError(
"Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
),
)
end

function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...)
return throw(
ArgumentError(
"Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
),
)
end
function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...)
return throw(
ArgumentError(
"Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
),
)
end
function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
dest.a .= a.a
dest.b .= a.b
return dest
end
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
if a.b == b.b
map!(+, dest.a, a.a, b.a)
dest.b .= a.b
elseif a.a == b.a
dest.a .= a.a
map!(+, dest.b, a.b, b.b)
else
throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or second arguments match.",
),
for f in [:+, :-]
@eval begin
function Base.map!(
::typeof($f), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray
)
if a.b == b.b
map!($f, dest.a, a.a, b.a)
dest.b .= a.b
elseif a.a == b.a
dest.a .= a.a
map!($f, dest.b, a.b, b.b)
else
throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or second arguments match.",
),
)
end
return dest
end
end
return dest
end
function Base.map!(
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
Expand All @@ -326,6 +388,16 @@ function Base.map!(
dest.b .= f.f.(a.b, f.x)
return dest
end
function Base.map!(
f::Base.Fix2{typeof(/),<:Number}, dest::KroneckerArray, a::KroneckerArray
)
return map!(Base.Fix2(*, inv(f.x)), dest, a)
end
function Base.map!(::typeof(conj), dest::KroneckerArray, a::KroneckerArray)
dest.a .= conj.(a.a)
dest.b .= conj.(a.b)
return dest
end

using LinearAlgebra:
LinearAlgebra,
Expand All @@ -343,9 +415,10 @@ using LinearAlgebra:
svd,
svdvals,
tr
diagonal(a::AbstractArray) = Diagonal(a)
function diagonal(a::KroneckerArray)
return Diagonal(a.a) ⊗ Diagonal(a.b)

using DiagonalArrays: DiagonalArrays, diagonal
function DiagonalArrays.diagonal(a::KroneckerArray)
return diagonal(a.a) ⊗ diagonal(a.b)
end

function Base.:*(a::KroneckerArray, b::KroneckerArray)
Expand All @@ -372,6 +445,23 @@ function LinearAlgebra.norm(a::KroneckerArray, p::Int=2)
return norm(a.a, p) ⊗ norm(a.b, p)
end

function Base.real(a::KroneckerArray)
if iszero(imag(a.a)) || iszero(imag(a.b))
return real(a.a) ⊗ real(a.b)
elseif iszero(real(a.a)) || iszero(real(a.b))
return -imag(a.a) ⊗ imag(a.b)
end
return real(a.a) ⊗ real(a.b) - imag(a.a) ⊗ imag(a.b)
end
function Base.imag(a::KroneckerArray)
if iszero(imag(a.a)) || iszero(real(a.b))
return real(a.a) ⊗ imag(a.b)
elseif iszero(real(a.a)) || iszero(imag(a.b))
return imag(a.a) ⊗ real(a.b)
end
return real(a.a) ⊗ imag(a.b) + imag(a.a) ⊗ real(a.b)
end

using MatrixAlgebraKit: MatrixAlgebraKit, diagview
function MatrixAlgebraKit.diagview(a::KroneckerMatrix)
return diagview(a.a) ⊗ diagview(a.b)
Expand Down Expand Up @@ -506,6 +596,19 @@ const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
const EyeEye{T,A<:Eye{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}

using DerivableInterfaces: DerivableInterfaces, zero!
function DerivableInterfaces.zero!(a::EyeKronecker)
zero!(a.b)
return a
end
function DerivableInterfaces.zero!(a::KroneckerEye)
zero!(a.a)
return a
end
function DerivableInterfaces.zero!(a::EyeEye)
return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`."))
end

function Base.:*(a::Number, b::EyeKronecker)
return b.a ⊗ (a * b.b)
end
Expand Down Expand Up @@ -580,29 +683,44 @@ end
function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye)
return error("Can't write in-place.")
end
function Base.map!(f::typeof(+), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker)
if dest.a ≠ a.a ≠ b.a
throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or second arguments match.",
),
)
for f in [:+, :-]
@eval begin
function Base.map!(::typeof($f), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker)
if dest.a ≠ a.a ≠ b.a
throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or second arguments match.",
),
)
end
map!($f, dest.b, a.b, b.b)
return dest
end
function Base.map!(::typeof($f), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye)
if dest.b ≠ a.b ≠ b.b
throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or second arguments match.",
),
)
end
map!($f, dest.a, a.a, b.a)
return dest
end
function Base.map!(::typeof($f), dest::EyeEye, a::EyeEye, b::EyeEye)
return error("Can't write in-place.")
end
end
map!(f, dest.b, a.b, b.b)
end
function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker)
map!(f, dest.b, a.b)
return dest
end
function Base.map!(f::typeof(+), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye)
if dest.b ≠ a.b ≠ b.b
throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or second arguments match.",
),
)
end
map!(f, dest.a, a.a, b.a)
function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye)
map!(f, dest.a, a.a)
return dest
end
function Base.map!(f::typeof(+), dest::EyeEye, a::EyeEye, b::EyeEye)
function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye)
return error("Can't write in-place.")
end
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
Expand Down Expand Up @@ -812,6 +930,74 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}

# Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`.
function Base.similar(
a::SquareEyeKronecker,
elt::Type,
axs::Tuple{
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
},
)
ax_a = map(ax -> ax.product.a, axs)
ax_b = map(ax -> ax.product.b, axs)
eye_ax_a = (only(unique(ax_a)),)
return Eye{elt}(eye_ax_a) ⊗ similar(a.b, elt, ax_b)
end
function Base.similar(
a::KroneckerSquareEye,
elt::Type,
axs::Tuple{
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
},
)
ax_a = map(ax -> ax.product.a, axs)
ax_b = map(ax -> ax.product.b, axs)
eye_ax_b = (only(unique(ax_b)),)
return similar(a.a, elt, ax_a) ⊗ Eye{elt}(eye_ax_b)
end
function Base.similar(
a::SquareEyeSquareEye,
elt::Type,
axs::Tuple{
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
},
)
ax_a = map(ax -> ax.product.a, axs)
ax_b = map(ax -> ax.product.b, axs)
eye_ax_a = (only(unique(ax_a)),)
eye_ax_b = (only(unique(ax_b)),)
return Eye{elt}(eye_ax_a) ⊗ Eye{elt}(eye_ax_b)
end

function Base.similar(
arrayt::Type{<:SquareEyeKronecker{<:Any,<:Any,A}},
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
) where {A}
ax_a = map(ax -> ax.product.a, axs)
ax_b = map(ax -> ax.product.b, axs)
eye_ax_a = (only(unique(ax_a)),)
return Eye{eltype(arrayt)}(eye_ax_a) ⊗ similar(A, ax_b)
end
function Base.similar(
arrayt::Type{<:KroneckerSquareEye{<:Any,A}},
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
) where {A}
ax_a = map(ax -> ax.product.a, axs)
ax_b = map(ax -> ax.product.b, axs)
eye_ax_b = (only(unique(ax_b)),)
return similar(A, ax_a) ⊗ Eye{eltype(arrayt)}(eye_ax_b)
end
function Base.similar(
arrayt::Type{<:SquareEyeSquareEye}, axs::NTuple{2,CartesianProductUnitRange{<:Integer}}
)
elt = eltype(arrayt)
ax_a = map(ax -> ax.product.a, axs)
ax_b = map(ax -> ax.product.b, axs)
eye_ax_a = (only(unique(ax_a)),)
eye_ax_b = (only(unique(ax_b)),)
return Eye{elt}(eye_ax_a) ⊗ Eye{elt}(eye_ax_b)
end

struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm
kwargs::KWargs
end
Expand Down Expand Up @@ -884,8 +1070,6 @@ for f in [:left_null!, :right_null!]
end
end
for f in [
:eig_full!,
:eigh_full!,
:qr_compact!,
:qr_full!,
:left_orth!,
Expand All @@ -900,10 +1084,14 @@ for f in [
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a)
end
end
_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye) = complex.((a, a))
_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye, alg) = complex.((a, a))
_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye) = (real(a), a)
_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye, alg) = (real(a), a)
for f in [:svd_compact!, :svd_full!]
@eval begin
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a, a)
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a, a)
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, real(a), a)
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, real(a), a)
end
end

Expand Down Expand Up @@ -987,10 +1175,12 @@ function MatrixAlgebraKit.right_null!(
return throw(MethodError(right_null!, (a, F)))
end

for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye) = parent(a)
_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye, alg) = parent(a)
for f in [:eigh_vals!, svd_vals!]
@eval begin
_initialize_output_squareeye(::typeof($f), a::SquareEye) = parent(a)
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = parent(a)
_initialize_output_squareeye(::typeof($f), a::SquareEye) = real(parent(a))
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = real(parent(a))
end
end

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
Loading
Loading