Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.3"
version = "0.1.4"

[deps]
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"

[compat]
DerivableInterfaces = "0.4.5"
FillArrays = "1.13.0"
GPUArraysCore = "0.2.0"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.2.0"
Expand Down
253 changes: 214 additions & 39 deletions src/KroneckerArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ end

arguments(a::KroneckerArray) = (a.a, a.b)
arguments(a::KroneckerArray, n::Int) = arguments(a)[n]
argument_types(a::KroneckerArray) = argument_types(typeof(a))
argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A,B}}) where {A,B} = (A, B)

function Base.print_array(io::IO, a::KroneckerArray)
Base.print_array(io, a.a)
Expand Down Expand Up @@ -234,6 +236,62 @@ function Base.:*(a::KroneckerArray, b::Number)
return a.a ⊗ (a.b * b)
end

function Base.:-(a::KroneckerArray)
return (-a.a) ⊗ a.b
end
for op in (:+, :-)
@eval begin
function Base.$op(a::KroneckerArray, b::KroneckerArray)
if a.b == b.b
return $op(a.a, b.a) ⊗ a.b
elseif a.a == b.a
return a.a ⊗ $op(a.b, b.b)
end
return throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or secord arguments match.",
),
)
end
end
end

function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
dest.a .= a.a
dest.b .= a.b
return dest
end
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
if a.b == b.b
map!(+, dest.a, a.a, b.a)
dest.b .= a.b
elseif a.a == b.a
dest.a .= a.a
map!(+, dest.b, a.b, b.b)
else
throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or second arguments match.",
),
)
end
return dest
end
function Base.map!(
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
)
dest.a .= f.f.(f.x, a.a)
dest.b .= a.b
return dest
end
function Base.map!(
f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
)
dest.a .= a.a
dest.b .= f.f.(a.b, f.x)
return dest
end

using LinearAlgebra:
LinearAlgebra,
Diagonal,
Expand Down Expand Up @@ -346,67 +404,138 @@ function LinearAlgebra.lq(a::KroneckerArray)
return KroneckerLQ(Fa.L ⊗ Fb.L, Fa.Q ⊗ Fb.Q)
end

function Base.:-(a::KroneckerArray)
using DerivableInterfaces: DerivableInterfaces, zero!
function DerivableInterfaces.zero!(a::KroneckerArray)
zero!(a.a)
zero!(a.b)
return a
end

using FillArrays: Eye
const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
const EyeEye{T,A<:Eye{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}

function Base.:*(a::Number, b::EyeKronecker)
return b.a ⊗ (a * b.b)
end
function Base.:*(a::Number, b::KroneckerEye)
return (a * b.a) ⊗ b.b
end
function Base.:*(a::Number, b::EyeEye)
return (a * b.a) ⊗ b.b
end
function Base.:*(a::EyeKronecker, b::Number)
return a.a ⊗ (a.b * b)
end
function Base.:*(a::KroneckerEye, b::Number)
return (a.a * b) ⊗ a.b
end
function Base.:*(a::EyeEye, b::Number)
return a.a ⊗ (a.b * b)
end

function Base.:-(a::EyeKronecker)
return a.a ⊗ (-a.b)
end
function Base.:-(a::KroneckerEye)
return (-a.a) ⊗ a.b
end
function Base.:-(a::EyeEye)
return (-a.a) ⊗ a.b
end
for op in (:+, :-)
@eval begin
function Base.$op(a::KroneckerArray, b::KroneckerArray)
if a.b == b.b
return $op(a.a, b.a) ⊗ a.b
elseif a.a == b.a
return a.a ⊗ $op(a.b, b.b)
function Base.$op(a::EyeKronecker, b::EyeKronecker)
if a.a ≠ b.a
return throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or secord arguments match.",
),
)
end
return throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or secord arguments match.",
),
)
return a.a ⊗ $op(a.b, b.b)
end
function Base.$op(a::KroneckerEye, b::KroneckerEye)
if a.b ≠ b.b
return throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or secord arguments match.",
),
)
end
return $op(a.a, b.a) ⊗ a.b
end
function Base.$op(a::EyeEye, b::EyeEye)
if a.b ≠ b.b
return throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or secord arguments match.",
),
)
end
return $op(a.a, b.a) ⊗ a.b
end
end
end

function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
dest.a .= a.a
function Base.map!(::typeof(identity), dest::EyeKronecker, a::EyeKronecker)
dest.b .= a.b
return dest
end
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
if a.b == b.b
map!(+, dest.a, a.a, b.a)
dest.b .= a.b
elseif a.a == b.a
dest.a .= a.a
map!(+, dest.b, a.b, b.b)
else
function Base.map!(::typeof(identity), dest::KroneckerEye, a::KroneckerEye)
dest.a .= a.a
return dest
end
function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye)
return error("Can't write in-place.")
end
function Base.map!(f::typeof(+), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker)
if dest.a ≠ a.a ≠ b.a
throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or second arguments match.",
),
)
end
map!(f, dest.b, a.b, b.b)
return dest
end
function Base.map!(
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
)
dest.a .= f.x .* a.a
dest.b .= a.b
function Base.map!(f::typeof(+), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye)
if dest.b ≠ a.b ≠ b.b
throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or second arguments match.",
),
)
end
map!(f, dest.a, a.a, b.a)
return dest
end
function Base.map!(
f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
)
dest.a .= a.a
dest.b .= a.b .* f.x
function Base.map!(f::typeof(+), dest::EyeEye, a::EyeEye, b::EyeEye)
return error("Can't write in-place.")
end
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
dest.b .= f.f.(f.x, a.b)
return dest
end

using DerivableInterfaces: DerivableInterfaces, zero!
function DerivableInterfaces.zero!(a::KroneckerArray)
zero!(a.a)
zero!(a.b)
return a
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye)
dest.a .= f.f.(f.x, a.a)
return dest
end
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeEye, a::EyeEye)
return error("Can't write in-place.")
end
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
dest.b .= f.f.(a.b, f.x)
return dest
end
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye)
dest.a .= f.f.(a.a, f.x)
return dest
end
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye)
return error("Can't write in-place.")
end

using MatrixAlgebraKit:
Expand Down Expand Up @@ -447,15 +576,61 @@ struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
b::B
end

using MatrixAlgebraKit:
copy_input,
eig_full,
eigh_full,
qr_compact,
qr_full,
left_polar,
lq_compact,
lq_full,
right_polar,
svd_compact,
svd_full

for f in [
:eig_full,
:eigh_full,
:qr_compact,
:qr_full,
:left_polar,
:lq_compact,
:lq_full,
:right_polar,
:svd_compact,
:svd_full,
]
@eval begin
function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix)
return copy_input($f, a.a) ⊗ copy_input($f, a.b)
end
end
end

for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
ff = Symbol("default_", f, "_algorithm")
@eval begin
function MatrixAlgebraKit.$ff(a::KroneckerMatrix; kwargs...)
return KroneckerAlgorithm($ff(a.a; kwargs...), $ff(a.b; kwargs...))
function MatrixAlgebraKit.$ff(A::Type{<:KroneckerMatrix}; kwargs...)
A1, A2 = argument_types(A)
return KroneckerAlgorithm($ff(A1; kwargs...), $ff(A2; kwargs...))
end
end
end

# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
function MatrixAlgebraKit.default_algorithm(
::typeof(qr_compact!), A::Type{<:KroneckerMatrix}; kwargs...
)
return default_qr_algorithm(A; kwargs...)
end
# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
function MatrixAlgebraKit.default_algorithm(
::typeof(qr_full!), A::Type{<:KroneckerMatrix}; kwargs...
)
return default_qr_algorithm(A; kwargs...)
end

for f in (
:eig_full!,
:eigh_full!,
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
Expand All @@ -9,6 +10,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Aqua = "0.8"
FillArrays = "1"
KroneckerArrays = "0.1"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.2"
Expand Down
15 changes: 15 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using FillArrays: Eye
using KroneckerArrays: KroneckerArrays, ⊗, ×, diagonal, kron_nd
using LinearAlgebra: Diagonal, I, eigen, eigvals, lq, qr, svd, svdvals, tr
using Test: @test, @testset
Expand Down Expand Up @@ -66,3 +67,17 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
@test collect(Q * R) ≈ collect(a)
@test collect(Q'Q) ≈ I
end

@testset "FillArrays.Eye" begin
a = Eye(2) ⊗ randn(3, 3)
@test size(a) == (6, 6)
@test a + a == Eye(2) ⊗ (2a.b)
@test 2a == Eye(2) ⊗ (2a.b)
@test a * a == Eye(2) ⊗ (a.b * a.b)

a = randn(3, 3) ⊗ Eye(2)
@test size(a) == (6, 6)
@test a + a == (2a.a) ⊗ Eye(2)
@test 2a == (2a.a) ⊗ Eye(2)
@test a * a == (a.a * a.a) ⊗ Eye(2)
end
Loading
Loading