Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.21"
version = "0.1.23"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
[compat]
Adapt = "4.3.0"
BlockArrays = "1.6"
BlockSparseArrays = "0.7.21"
BlockSparseArrays = "0.7.22"
DerivableInterfaces = "0.5.0"
DiagonalArrays = "0.3.5"
FillArrays = "1.13.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,10 @@ function (f::GetUnstoredBlock)(
return error("Not implemented.")
end

using BlockSparseArrays: BlockSparseArrays
using KroneckerArrays: KroneckerArrays, KroneckerVector
function BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I)
return KroneckerArrays.to_truncated_indices(values, I)
end

end
72 changes: 66 additions & 6 deletions src/cartesianproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,16 @@ arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
arg1(a::CartesianProduct) = a.a
arg2(a::CartesianProduct) = a.b

Base.copy(a::CartesianProduct) = copy(arg1(a)) × copy(arg2(a))

function Base.show(io::IO, a::CartesianProduct)
print(io, a.a, " × ", a.b)
return nothing
end
function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct)
show(io, a)
return nothing
end

×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
Expand All @@ -42,8 +48,38 @@ function Base.getindex(a::CartesianProduct, i::CartesianPair)
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
end
function Base.getindex(a::CartesianProduct, i::Int)
I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i])
return a[I[1] × I[2]]
I = Tuple(CartesianIndices((length(arg2(a)), length(arg1(a))))[i])
return a[I[2] × I[1]]
end

struct CartesianProductVector{T,P<:CartesianProduct,V<:AbstractVector{T}} <:
AbstractVector{T}
product::P
values::V
end
cartesianproduct(r::CartesianProductVector) = getfield(r, :product)
unproduct(r::CartesianProductVector) = getfield(r, :values)
Base.length(a::CartesianProductVector) = length(unproduct(a))
Base.size(a::CartesianProductVector) = (length(a),)
function Base.axes(r::CartesianProductVector)
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
end
function Base.copy(a::CartesianProductVector)
return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a)))
end
function Base.getindex(r::CartesianProductVector, i::Integer)
return unproduct(r)[i]
end

function Base.show(io::IO, a::CartesianProductVector)
show(io, unproduct(a))
return nothing
end
function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductVector)
show(io, mime, cartesianproduct(a))
println(io)
show(io, mime, unproduct(a))
return nothing
end

struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <:
Expand All @@ -60,13 +96,24 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))

function Base.show(io::IO, a::CartesianProductUnitRange)
show(io, unproduct(a))
return nothing
end
function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductUnitRange)
show(io, mime, cartesianproduct(a))
println(io)
show(io, mime, unproduct(a))
return nothing
end

function CartesianProductUnitRange(p::CartesianProduct)
return CartesianProductUnitRange(p, Base.OneTo(length(p)))
end
function CartesianProductUnitRange(a, b)
return CartesianProductUnitRange(a × b)
end
to_product_indices(a::AbstractUnitRange) = a
to_product_indices(a::AbstractVector) = a
to_product_indices(i::Integer) = Base.OneTo(i)
cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b))
function cartesianrange(p::CartesianPair)
Expand All @@ -87,17 +134,30 @@ function cartesianrange(p::CartesianProduct, range::AbstractUnitRange)
end

function Base.axes(r::CartesianProductUnitRange)
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
prod = cartesianproduct(r)
prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod)))
return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),)
end

function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair)
return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i))
end

const CartesianProductOneTo{T,P<:CartesianProduct,R<:Base.OneTo{T}} = CartesianProductUnitRange{T,P,R}
Base.axes(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,)
Base.axes1(S::Base.Slice{<:CartesianProductOneTo}) = S.indices
Base.unsafe_indices(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,)

function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct)
prod = cartesianproduct(a)
prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)]
return CartesianProductVector(prod_I, map(Base.Fix1(getindex, a), I))
end

# Reverse map from CartesianPair to linear index in the range.
function Base.getindex(inds::CartesianProductUnitRange, i::CartesianPair)
i′ = (findfirst(==(arg1(i)), arg1(inds)), findfirst(==(arg2(i)), arg2(inds)))
return inds[LinearIndices((length(arg1(inds)), length(arg2(inds))))[i′...]]
i′ = (findfirst(==(arg2(i)), arg2(inds)), findfirst(==(arg1(i)), arg1(inds)))
return inds[LinearIndices((length(arg2(inds)), length(arg1(inds))))[i′...]]
end

using Base.Broadcast: DefaultArrayStyle
Expand Down
6 changes: 6 additions & 0 deletions src/fillarrays/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatr
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}

_getindex(a::Eye, I1::Colon, I2::Colon) = a
_getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
_getindex(a::Eye, I1::Base.Slice, I2::Colon) = a
_getindex(a::Eye, I1::Colon, I2::Base.Slice) = a
_view(a::Eye, I1::Colon, I2::Colon) = a
_view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
_view(a::Eye, I1::Base.Slice, I2::Colon) = a
_view(a::Eye, I1::Colon, I2::Base.Slice) = a

# Like `adapt` but preserves `Eye`.
_adapt(to, a::Eye) = a
Expand Down
42 changes: 24 additions & 18 deletions src/fillarrays/matrixalgebrakit_truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,40 @@ const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVe
const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}

function MatrixAlgebraKit.findtruncated(
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
)
I = findtruncated(Vector(values), strategy.strategy)
prods = collect(cartesianproduct(only(axes(values))))[I]
I_data = unique(map(arg1, prods))
axis(a) = only(axes(a))

# Convert indices determined with a generic call to `findtruncated` to indices
# more suited for a KroneckerVector.
function to_truncated_indices(values::OnesKroneckerVector, I)
prods = cartesianproduct(axis(values))[I]
I_id = only(to_indices(arg1(values), (:,)))
I_data = unique(arg2.(prods))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> arg1(x) == i, prods) == length(arg1(values))
return count(x -> arg2(x) == i, prods) == length(arg2(values))
end
return (:) × I_data
return I_id × I_data
end
function MatrixAlgebraKit.findtruncated(
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
)
I = findtruncated(Vector(values), strategy.strategy)
prods = collect(cartesianproduct(only(axes(values))))[I]
I_data = unique(map(x -> arg2(x), prods))
function to_truncated_indices(values::KroneckerOnesVector, I)
#I = findtruncated(Vector(values), strategy.strategy)
prods = cartesianproduct(axis(values))[I]
I_data = unique(arg1.(prods))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> arg2(x) == i, prods) == length(arg2(values))
return count(x -> arg1(x) == i, prods) == length(arg2(values))
end
return I_data × (:)
I_id = only(to_indices(arg2(values), (:,)))
return I_data × I_id
end
function to_truncated_indices(values::OnesVectorOnesVector, I)
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
end

function MatrixAlgebraKit.findtruncated(
values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy
values::KroneckerVector, strategy::KroneckerTruncationStrategy
)
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
I = findtruncated(Vector(values), strategy.strategy)
return to_truncated_indices(values, I)
end

for f in [:eig_trunc!, :eigh_trunc!]
Expand Down
38 changes: 28 additions & 10 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ function Base.copyto!(dest::KroneckerArray, src::KroneckerArray)
return dest
end

function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where {T,N,A,B}
return KroneckerArray(convert(A, arg1(a)), convert(B, arg2(a)))
end

# Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`.
function _similar(a::AbstractArray, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}})
return similar(a, elt, axs)
Expand Down Expand Up @@ -167,20 +171,36 @@ function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {
return a[I′...]
end

# Indexing logic.
function Base.to_indices(
a::KroneckerArray, inds, I::Tuple{Union{CartesianPair,CartesianProduct},Vararg}
)
I1 = to_indices(arg1(a), arg1.(inds), arg1.(I))
I2 = to_indices(arg2(a), arg2.(inds), arg2.(I))
return I1 .× I2
end

# Allow customizing for `FillArrays.Eye`.
_getindex(a::AbstractArray, I...) = a[I...]
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...)
end
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N}
return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...)
function Base.getindex(
a::KroneckerArray{<:Any,N}, I::Vararg{Union{CartesianPair,CartesianProduct},N}
) where {N}
I′ = to_indices(a, I)
return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...)
end
# Fix ambigiuity error.
Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[]

# Allow customizing for `FillArrays.Eye`.
_view(a::AbstractArray, I...) = view(a, I...)
function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
arg1(::Colon) = (:)
arg2(::Colon) = (:)
arg1(::Base.Slice) = (:)
arg2(::Base.Slice) = (:)
function Base.view(
a::KroneckerArray{<:Any,N},
I::Vararg{Union{CartesianProduct,CartesianProductUnitRange,Base.Slice,Colon},N},
) where {N}
return _view(arg1(a), arg1.(I)...) ⊗ _view(arg2(a), arg2.(I)...)
end
function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N}
Expand Down Expand Up @@ -263,10 +283,8 @@ function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N
return KroneckerStyle{N}(style_a, style_b)
end
function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type, ax) where {N,A,B}
ax_a = arg1.(ax)
ax_b = arg2.(ax)
bc_a = Broadcasted(A, nothing, (), ax_a)
bc_b = Broadcasted(B, nothing, (), ax_b)
bc_a = Broadcasted(A, bc.f, arg1.(bc.args), arg1.(ax))
bc_b = Broadcasted(B, bc.f, arg2.(bc.args), arg2.(ax))
a = similar(bc_a, elt)
b = similar(bc_b, elt)
return a ⊗ b
Expand Down
18 changes: 9 additions & 9 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@testset "KroneckerArrays (eltype=$elt)" for elt in elts
p = [1, 2] × [3, 4, 5]
@test length(p) == 6
@test collect(p) == [1 × 3, 2 × 3, 1 × 4, 2 × 4, 1 × 5, 2 × 5]
@test collect(p) == [1 × 3, 1 × 4, 1 × 5, 2 × 3, 2 × 4, 2 × 5]

r = @constinferred cartesianrange(2, 3)
@test r ===
Expand All @@ -39,10 +39,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test first(r) == 1
@test last(r) == 6
@test r[1 × 1] == 1
@test r[2 × 1] == 2
@test r[1 × 2] == 3
@test r[2 × 2] == 4
@test r[1 × 3] == 5
@test r[1 × 2] == 2
@test r[1 × 3] == 3
@test r[2 × 1] == 4
@test r[2 × 2] == 5
@test r[2 × 3] == 6

r = @constinferred(cartesianrange(2 × 3, 2:7))
Expand All @@ -53,10 +53,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test first(r) == 2
@test last(r) == 7
@test r[1 × 1] == 2
@test r[2 × 1] == 3
@test r[1 × 2] == 4
@test r[2 × 2] == 5
@test r[1 × 3] == 6
@test r[1 × 2] == 3
@test r[1 × 3] == 4
@test r[2 × 1] == 5
@test r[2 × 2] == 6
@test r[2 × 3] == 7

# Test high-dimensional materialization.
Expand Down
34 changes: 30 additions & 4 deletions test/test_blocksparsearrays.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using Adapt: adapt
using BlockArrays: Block, BlockRange, mortar
using BlockSparseArrays:
BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype
BlockIndexVector, BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype
using FillArrays: Eye, SquareEye
using JLArrays: JLArray
using KroneckerArrays: KroneckerArray, ⊗, ×
using KroneckerArrays: KroneckerArray, ⊗, ×, arg1, arg2
using LinearAlgebra: norm
using MatrixAlgebraKit: svd_compact
using Test: @test, @test_broken, @testset
Expand Down Expand Up @@ -48,7 +48,18 @@ arrayts = (Array, JLArray)
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
@test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] ==
a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]

# Blockwise slicing, shows up in truncated block sparse matrix factorizations.
I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1])
I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3])
I = [I1, I2]
b = a[I, I]
@test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]]
@test iszero(b[Block(2, 1)])
@test iszero(b[Block(1, 2)])
@test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]]

# Slicing
r = blockrange([2 × 2, 3 × 3])
Expand Down Expand Up @@ -159,7 +170,22 @@ end
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
@test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] ==
a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]

# Blockwise slicing, shows up in truncated block sparse matrix factorizations.
I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1])
I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3])
I = [I1, I2]
b = a[I, I]
@test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]]
@test arg1(b[Block(1, 1)]) isa Eye
@test iszero(b[Block(2, 1)])
@test arg1(b[Block(2, 1)]) isa Eye
@test iszero(b[Block(1, 2)])
@test arg1(b[Block(1, 2)]) isa Eye
@test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]]
@test arg1(b[Block(2, 2)]) isa Eye

# Slicing
r = blockrange([2 × 2, 3 × 3])
Expand Down