Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.6"
version = "0.1.7"

[deps]
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
Expand Down
189 changes: 176 additions & 13 deletions src/KroneckerArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ end
arguments(a::CartesianProduct) = (a.a, a.b)
arguments(a::CartesianProduct, n::Int) = arguments(a)[n]

function Base.show(io::IO, a::CartesianProduct)
print(io, a.a, " × ", a.b)
return nothing
end

×(a, b) = CartesianProduct(a, b)
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b]
Expand Down Expand Up @@ -130,6 +135,8 @@ function interleave(x::Tuple, y::Tuple)
xy = ntuple(i -> (x[i], y[i]), length(x))
return flatten(xy)
end
# TODO: Maybe use scalar indexing based on KroneckerProducts.jl logic for cartesian indexing:
# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66
function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N}
a′ = reshape(a, interleave(size(a), ntuple(one, N)))
b′ = reshape(b, interleave(ntuple(one, N), size(b)))
Expand Down Expand Up @@ -183,6 +190,9 @@ function Base.getindex(a::KroneckerArray, i::Integer)
return a[CartesianIndices(a)[i]]
end

# TODO: Use this logic from KroneckerProducts.jl for cartesian indexing
# in the n-dimensional case and use it to replace the matrix and vector cases:
# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N}
return error("Not implemented.")
end
Expand Down Expand Up @@ -222,6 +232,10 @@ end
function Base.inv(a::KroneckerArray)
return inv(a.a) ⊗ inv(a.b)
end
using LinearAlgebra: LinearAlgebra, pinv
function LinearAlgebra.pinv(a::KroneckerArray; kwargs...)
return pinv(a.a; kwargs...) ⊗ pinv(a.b; kwargs...)
end
function Base.transpose(a::KroneckerArray)
return transpose(a.a) ⊗ transpose(a.b)
end
Expand Down Expand Up @@ -297,6 +311,7 @@ using LinearAlgebra:
Diagonal,
Eigen,
SVD,
det,
diag,
eigen,
eigvals,
Expand Down Expand Up @@ -335,9 +350,63 @@ end
function LinearAlgebra.norm(a::KroneckerArray, p::Int=2)
return norm(a.a, p) ⊗ norm(a.b, p)
end

using MatrixAlgebraKit: MatrixAlgebraKit, diagview
function MatrixAlgebraKit.diagview(a::KroneckerMatrix)
return diagview(a.a) ⊗ diagview(a.b)
end
function LinearAlgebra.diag(a::KroneckerArray)
return diag(a.a) ⊗ diag(a.b)
return copy(diagview(a.a)) ⊗ copy(diagview(a.b))
end

# Matrix functions
const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]

for f in MATRIX_FUNCTIONS
@eval begin
function Base.$f(a::KroneckerArray)
return throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported."))
end
end
end

using LinearAlgebra: checksquare
function LinearAlgebra.det(a::KroneckerArray)
checksquare(a.a)
checksquare(a.b)
return det(a.a) ^ size(a.b, 1) * det(a.b) ^ size(a.a, 1)
end

function LinearAlgebra.svd(a::KroneckerArray)
Fa = svd(a.a)
Fb = svd(a.b)
Expand Down Expand Up @@ -690,18 +759,6 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
end
end

for f in [:eig_trunc!, :eigh_trunc!, :svd_trunc!]
@eval begin
function MatrixAlgebraKit.truncate!(
::typeof($f),
(D, V)::Tuple{KroneckerMatrix,KroneckerMatrix},
strategy::TruncationStrategy,
)
return throw(MethodError(truncate!, ($f, (D, V), strategy)))
end
end
end

for f in [:left_orth!, :right_orth!]
@eval begin
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
Expand Down Expand Up @@ -941,4 +998,110 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
end
end

using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate!

struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
strategy::T
end

# Avoid instantiating the identity.
function Base.getindex(a::SquareEyeKronecker, I::Vararg{CartesianProduct{Colon},2})
return a.a ⊗ a.b[I[1].b, I[2].b]
end
function Base.getindex(a::KroneckerSquareEye, I::Vararg{CartesianProduct{<:Any,Colon},2})
return a.a[I[1].a, I[2].a] ⊗ a.b
end
function Base.getindex(a::SquareEyeSquareEye, I::Vararg{CartesianProduct{Colon,Colon},2})
return a
end

using FillArrays: OnesVector
const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B}
const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}

function MatrixAlgebraKit.findtruncated(
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
)
I = findtruncated(Vector(values), strategy.strategy)
prods = collect(only(axes(values)).product)[I]
I_data = unique(map(x -> x.a, prods))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> x.a == i, prods) == length(values.a)
end
return (:) × I_data
end
function MatrixAlgebraKit.findtruncated(
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
)
I = findtruncated(Vector(values), strategy.strategy)
prods = collect(only(axes(values)).product)[I]
I_data = unique(map(x -> x.b, prods))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> x.b == i, prods) == length(values.b)
end
return I_data × (:)
end
function MatrixAlgebraKit.findtruncated(
values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy
)
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
end

for f in [:eig_trunc!, :eigh_trunc!]
@eval begin
function MatrixAlgebraKit.truncate!(
::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy
)
return truncate!($f, DV, KroneckerTruncationStrategy(strategy))
end
function MatrixAlgebraKit.truncate!(
::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy
)
I = findtruncated(diagview(D), strategy)
return (D[I, I], V[(:) × (:), I])
end
end
end

function MatrixAlgebraKit.truncate!(
f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy
)
return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy))
end
function MatrixAlgebraKit.truncate!(
::typeof(svd_trunc!),
(U, S, Vᴴ)::NTuple{3,KroneckerMatrix},
strategy::KroneckerTruncationStrategy,
)
I = findtruncated(diagview(S), strategy)
return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)])
end

for f in MATRIX_FUNCTIONS
@eval begin
function Base.$f(a::SquareEyeKronecker)
return a.a ⊗ $f(a.b)
end
function Base.$f(a::KroneckerSquareEye)
return $f(a.a) ⊗ a.b
end
function Base.$f(a::SquareEyeSquareEye)
return throw(ArgumentError("`$($f)` on `Eye ⊗ Eye` is not supported."))
end
end
end

function LinearAlgebra.pinv(a::SquareEyeKronecker; kwargs...)
return a.a ⊗ pinv(a.b; kwargs...)
end
function LinearAlgebra.pinv(a::KroneckerSquareEye; kwargs...)
return pinv(a.a; kwargs...) ⊗ a.b
end
function LinearAlgebra.pinv(a::SquareEyeSquareEye; kwargs...)
return a
end

end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Expand All @@ -16,6 +17,7 @@ KroneckerArrays = "0.1"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.2"
SafeTestsets = "0.1"
StableRNGs = "1.0"
Suppressor = "0.2"
Test = "1.10"
TestExtras = "0.3"
85 changes: 82 additions & 3 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using FillArrays: Eye
using KroneckerArrays: KroneckerArrays, ⊗, ×, diagonal, kron_nd
using LinearAlgebra: Diagonal, I, eigen, eigvals, lq, qr, svd, svdvals, tr
using Test: @test, @testset
using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, pinv, qr, svd, svdvals, tr
using StableRNGs: StableRNG
using Test: @test, @test_broken, @test_throws, @testset

const elts = (Float32, Float64, ComplexF32, ComplexF64)
@testset "KroneckerArrays (eltype=$elt)" for elt in elts
Expand Down Expand Up @@ -35,7 +36,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
@test iszero(a - a)
@test collect(a + c) ≈ collect(a) + collect(c)
@test collect(b + c) ≈ collect(b) + collect(c)
for f in (transpose, adjoint, inv)
for f in (transpose, adjoint, inv, pinv)
@test collect(f(a)) ≈ f(collect(a))
end
@test tr(a) ≈ tr(collect(a))
Expand Down Expand Up @@ -66,9 +67,25 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
Q, R = qr(a)
@test collect(Q * R) ≈ collect(a)
@test collect(Q'Q) ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
@test det(a) ≈ det(collect(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
for f in KroneckerArrays.MATRIX_FUNCTIONS
@eval begin
@test_throws ArgumentError $f($a)
end
end
end

@testset "FillArrays.Eye" begin
MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS
if VERSION < v"1.11-"
# `cbrt(::AbstractMatrix{<:Real})` was implemented in Julia 1.11.
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
end

a = Eye(2) ⊗ randn(3, 3)
@test size(a) == (6, 6)
@test a + a == Eye(2) ⊗ (2a.b)
Expand All @@ -80,4 +97,66 @@ end
@test a + a == (2a.a) ⊗ Eye(2)
@test 2a == (2a.a) ⊗ Eye(2)
@test a * a == (a.a * a.a) ⊗ Eye(2)

# Eye ⊗ A
rng = StableRNG(123)
a = Eye(2) ⊗ randn(rng, 3, 3)
for f in MATRIX_FUNCTIONS
@eval begin
fa = $f($a)
@test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a))))
@test fa.a isa Eye
end
end

fa = inv(a)
@test collect(fa) ≈ inv(collect(a))
@test fa.a isa Eye

fa = pinv(a)
@test collect(fa) ≈ pinv(collect(a))
@test fa.a isa Eye

@test det(a) ≈ det(collect(a))

# A ⊗ Eye
rng = StableRNG(123)
a = randn(rng, 3, 3) ⊗ Eye(2)
for f in setdiff(MATRIX_FUNCTIONS, [:atanh])
@eval begin
fa = $f($a)
@test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a))))
@test fa.b isa Eye
end
end

fa = inv(a)
@test collect(fa) ≈ inv(collect(a))
@test fa.b isa Eye

fa = pinv(a)
@test collect(fa) ≈ pinv(collect(a))
@test fa.b isa Eye

@test det(a) ≈ det(collect(a))

# Eye ⊗ Eye
a = Eye(2) ⊗ Eye(2)
for f in KroneckerArrays.MATRIX_FUNCTIONS
@eval begin
@test_throws ArgumentError $f($a)
end
end

fa = inv(a)
@test fa == a
@test fa.a isa Eye
@test fa.b isa Eye

fa = pinv(a)
@test fa == a
@test fa.a isa Eye
@test fa.b isa Eye

@test det(a) ≈ det(collect(a)) ≈ 1
end
Loading
Loading