Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.18"
version = "0.1.19"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,53 @@
module KroneckerArraysBlockSparseArraysExt

using BlockSparseArrays: BlockSparseArrays, blockrange
using KroneckerArrays: CartesianProduct, cartesianrange
using BlockArrays: Block
using BlockSparseArrays: BlockIndexVector, GenericBlockIndex
using KroneckerArrays: CartesianPair, CartesianProduct
function Base.getindex(b::Block, I1::CartesianPair, Irest::CartesianPair...)
return GenericBlockIndex(b, (I1, Irest...))
end
function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...)
return BlockIndexVector(b, (I1, Irest...))
end

using BlockSparseArrays: BlockSparseArrays, BlockUnitRange, blockrange
using KroneckerArrays: CartesianPair, CartesianProduct, ×, cartesianrange

function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
return blockrange(map(cartesianrange, bs))
return blockrange(cartesianrange.(bs))
end
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair{<:Integer,<:Integer}})
bs′ = map(bs) do b
return Base.OneTo(arg1(b)) × Base.OneTo(arg2(b))
end
return blockrange(bs′)
end

using BlockSparseArrays: BlockSparseArrays, infimum
using KroneckerArrays: cartesianproduct, CartesianProductUnitRange
function BlockSparseArrays.infimum(r1::CartesianProductUnitRange, r2::CartesianProductUnitRange)
return cartesianrange(infimum(cartesianproduct.((r1, r2))...))
end
function BlockSparseArrays.infimum(r1::CartesianProduct, r2::CartesianProduct)
return infimum(arg1(r1), arg1(r2)) × infimum(arg2(r1), arg2(r2))
end

using BlockArrays: Block
using KroneckerArrays: cartesianrange
function Base.getindex(
r::BlockUnitRange{<:Integer,<:Vector{<:CartesianProduct}}, I::Block{1,Int64}
)
prod = eachblockaxis(r)[Int(I)]
range = r.r[I]
return cartesianrange(prod, range)
end

# Fix ambiguity error with BlockArrays.jl.
using BlockArrays: AbstractBlockArray
function Base.similar(
a::AbstractBlockArray, axs::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}}
)
return similar(a, eltype(a), axs)
end

using BlockArrays: AbstractBlockedUnitRange
Expand Down
92 changes: 75 additions & 17 deletions src/cartesianproduct.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,22 @@
struct CartesianProduct{A,B}
struct CartesianPair{A,B}
a::A
b::B
end
arguments(a::CartesianPair) = (a.a, a.b)
arguments(a::CartesianPair, n::Int) = arguments(a)[n]

arg1(a::CartesianPair) = a.a
arg2(a::CartesianPair) = a.b

×(a, b) = CartesianPair(a, b)

function Base.show(io::IO, a::CartesianPair)
print(io, a.a, " × ", a.b)
return nothing
end

struct CartesianProduct{TA,TB,A<:AbstractVector{TA},B<:AbstractVector{TB}} <:
AbstractVector{CartesianPair{TA,TB}}
a::A
b::B
end
Expand All @@ -13,17 +31,44 @@ function Base.show(io::IO, a::CartesianProduct)
return nothing
end

×(a, b) = CartesianProduct(a, b)
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b]
# This is used when printing block sparse arrays with KroneckerArray
# blocks.
# TODO: Investigate if this is needed or if it can be avoided
# by iterating over CartesianProduct axes.
function Base.checkindex(::Type{Bool}, inds::CartesianProduct, i::Int)
return checkindex(Bool, Base.OneTo(length(inds)), i)
end

×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
Base.length(a::CartesianProduct) = length(arg1(a)) * length(arg2(a))
Base.size(a::CartesianProduct) = (length(a),)
function Base.getindex(a::CartesianProduct, i::CartesianProduct)
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
end
function Base.getindex(a::CartesianProduct, i::CartesianPair)
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
end
function Base.getindex(a::CartesianProduct, i::Int)
I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i])
return a[I[1] × I[2]]
end

using Base: promote_shape
function Base.promote_shape(
a::Tuple{Vararg{CartesianProduct}}, b::Tuple{Vararg{CartesianProduct}}
)
return promote_shape(arg1.(a), arg1.(b)) × promote_shape(arg2.(a), arg2.(b))
end

function Base.iterate(a::CartesianProduct, state...)
x = iterate(Iterators.product(a.a, a.b), state...)
isnothing(x) && return x
next, new_state = x
return ×(next...), new_state
using Base.Broadcast: axistype
function Base.Broadcast.axistype(r1::CartesianProduct, r2::CartesianProduct)
return axistype(arg1(r1), arg1(r2)) × axistype(arg2(r1), arg2(r2))
end

## function Base.to_index(A::KroneckerArray, I::CartesianProduct)
## return I
## end

struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <:
AbstractUnitRange{T}
product::P
Expand All @@ -38,27 +83,36 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))

function Base.show(io::IO, r::CartesianProductUnitRange)
print(io, cartesianproduct(r), ": ", unproduct(r))
return nothing
end
function Base.show(io::IO, mime::MIME"text/plain", r::CartesianProductUnitRange)
show(io, mime, cartesianproduct(r))
println(io)
show(io, mime, unproduct(r))
return nothing
end

function CartesianProductUnitRange(p::CartesianProduct)
return CartesianProductUnitRange(p, Base.OneTo(length(p)))
end
function CartesianProductUnitRange(a, b)
return CartesianProductUnitRange(a × b)
end
to_range(a::AbstractUnitRange) = a
to_range(i::Integer) = Base.OneTo(i)
cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b))
to_product_indices(a::AbstractVector) = a
to_product_indices(i::Integer) = Base.OneTo(i)
cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b))
function cartesianrange(p::CartesianProduct)
p′ = to_range(p.a) × to_range(p.b)
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
return cartesianrange(p′, Base.OneTo(length(p′)))
end
function cartesianrange(p::CartesianProduct, range::AbstractUnitRange)
p′ = to_range(p.a) × to_range(p.b)
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
return CartesianProductUnitRange(p′, range)
end

function Base.axes(r::CartesianProductUnitRange)
return (CartesianProductUnitRange(r.product, only(axes(r.range))),)
end
Base.axes(r::CartesianProductUnitRange) = (cartesianrange(cartesianproduct(r)),)

using Base.Broadcast: DefaultArrayStyle
for f in (:+, :-)
Expand All @@ -84,3 +138,7 @@ function Base.Broadcast.axistype(
range = axistype(unproduct(r1), unproduct(r2))
return cartesianrange(prod, range)
end

function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair)
return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i))
end
14 changes: 10 additions & 4 deletions src/fillarrays/kroneckerarray.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
using FillArrays: FillArrays, Zeros
function FillArrays.fillsimilar(
a::Zeros{T},
ax::Tuple{
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
},
a::Zeros{T}, ax::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}}
) where {T}
return Zeros{T}(arg1.(ax)) ⊗ Zeros{T}(arg2.(ax))
end

# Work around that `Zeros` requires `AbstractUnitRange` axes.
function FillArrays.Zeros{T,N}(
ax::Tuple{CartesianProduct,Vararg{CartesianProduct}}
) where {T,N}
return Zeros{T,N}(cartesianslice.(ax))
end

using FillArrays: RectDiagonal, OnesVector
const RectEye{T,V<:OnesVector{T},Axes} = RectDiagonal{T,V,Axes}

Expand Down Expand Up @@ -68,6 +72,8 @@ end
# Like `copy` but preserves `Eye`.
_copy(a::Eye) = a

_getindex(a::Eye, I1::Colon, I2::Colon) = a

using DerivableInterfaces: DerivableInterfaces, zero!
function DerivableInterfaces.zero!(a::EyeKronecker)
zero!(a.b)
Expand Down
10 changes: 8 additions & 2 deletions src/fillarrays/matrixalgebrakit.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
function infimum(r1::AbstractRange, r2::AbstractUnitRange)
function infimum(r1::AbstractUnitRange, r2::AbstractUnitRange)
Base.require_one_based_indexing(r1, r2)
if length(r1) ≤ length(r2)
return r1
else
return r2
end
end
function supremum(r1::AbstractRange, r2::AbstractUnitRange)
function infimum(r1::CartesianProduct, r2::CartesianProduct)
return infimum(arg1(r1), arg1(r2)) × infimum(arg2(r1), arg2(r2))
end
function supremum(r1::AbstractUnitRange, r2::AbstractUnitRange)
Base.require_one_based_indexing(r1, r2)
if length(r1) ≥ length(r2)
return r1
else
return r2
end
end
function supremum(r1::CartesianProduct, r2::CartesianProduct)
return supremum(arg1(r1), arg1(r2)) × supremum(arg2(r1), arg2(r2))
end

# Allow customization for `Eye`.
_diagview(a::Eye) = parent(a)
Expand Down
12 changes: 6 additions & 6 deletions src/fillarrays/matrixalgebrakit_truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,23 @@ function MatrixAlgebraKit.findtruncated(
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
)
I = findtruncated(Vector(values), strategy.strategy)
prods = collect(only(axes(values)).product)[I]
I_data = unique(map(x -> x.a, prods))
prods = only(axes(values))[I]
I_data = unique(arg1.(prods))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> x.a == i, prods) == length(values.a)
return count(x -> arg1(x) == i, prods) == length(arg1(values))
end
return (:) × I_data
end
function MatrixAlgebraKit.findtruncated(
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
)
I = findtruncated(Vector(values), strategy.strategy)
prods = collect(only(axes(values)).product)[I]
I_data = unique(map(x -> x.b, prods))
prods = only(axes(values))[I]
I_data = unique(map(x -> arg2(x), prods))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> x.b == i, prods) == length(values.b)
return count(x -> arg2(x) == i, prods) == length(arg2(values))
end
return I_data × (:)
end
Expand Down
Loading
Loading