Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.18"
version = "0.1.19"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
[compat]
Adapt = "4.3.0"
BlockArrays = "1.6"
BlockSparseArrays = "0.7.19"
BlockSparseArrays = "0.7.20"
DerivableInterfaces = "0.5.0"
DiagonalArrays = "0.3.5"
FillArrays = "1.13.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
module KroneckerArraysBlockSparseArraysExt

using BlockArrays: Block
using BlockSparseArrays: BlockIndexVector, GenericBlockIndex
using KroneckerArrays: CartesianPair, CartesianProduct
function Base.getindex(b::Block, I1::CartesianPair, Irest::CartesianPair...)
return GenericBlockIndex(b, (I1, Irest...))
end
function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...)
return BlockIndexVector(b, (I1, Irest...))
end

using BlockSparseArrays: BlockSparseArrays, blockrange
using KroneckerArrays: CartesianProduct, cartesianrange
using KroneckerArrays: CartesianPair, CartesianProduct, cartesianrange
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair})
return blockrange(map(cartesianrange, bs))
end
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
return blockrange(map(cartesianrange, bs))
end
Expand Down
66 changes: 50 additions & 16 deletions src/cartesianproduct.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,22 @@
struct CartesianProduct{A,B}
struct CartesianPair{A,B}
a::A
b::B
end
arguments(a::CartesianPair) = (a.a, a.b)
arguments(a::CartesianPair, n::Int) = arguments(a)[n]

arg1(a::CartesianPair) = a.a
arg2(a::CartesianPair) = a.b

×(a, b) = CartesianPair(a, b)

function Base.show(io::IO, a::CartesianPair)
print(io, a.a, " × ", a.b)
return nothing
end

struct CartesianProduct{TA,TB,A<:AbstractVector{TA},B<:AbstractVector{TB}} <:
AbstractVector{CartesianPair{TA,TB}}
a::A
b::B
end
Expand All @@ -13,15 +31,19 @@ function Base.show(io::IO, a::CartesianProduct)
return nothing
end

×(a, b) = CartesianProduct(a, b)
×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b]
Base.size(a::CartesianProduct) = (length(a),)

function Base.iterate(a::CartesianProduct, state...)
x = iterate(Iterators.product(a.a, a.b), state...)
isnothing(x) && return x
next, new_state = x
return ×(next...), new_state
function Base.getindex(a::CartesianProduct, i::CartesianProduct)
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
end
function Base.getindex(a::CartesianProduct, i::CartesianPair)
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
end
function Base.getindex(a::CartesianProduct, i::Int)
I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i])
return a[I[1] × I[2]]
end

struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <:
Expand All @@ -44,20 +66,32 @@ end
function CartesianProductUnitRange(a, b)
return CartesianProductUnitRange(a × b)
end
to_range(a::AbstractUnitRange) = a
to_range(i::Integer) = Base.OneTo(i)
cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b))
to_product_indices(a::AbstractUnitRange) = a
to_product_indices(i::Integer) = Base.OneTo(i)
cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b))
function cartesianrange(p::CartesianPair)
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
return cartesianrange(p′)
end
function cartesianrange(p::CartesianProduct)
p′ = to_range(p.a) × to_range(p.b)
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
return cartesianrange(p′, Base.OneTo(length(p′)))
end
function cartesianrange(p::CartesianPair, range::AbstractUnitRange)
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
return cartesianrange(p′, range)
end
function cartesianrange(p::CartesianProduct, range::AbstractUnitRange)
p′ = to_range(p.a) × to_range(p.b)
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
return CartesianProductUnitRange(p′, range)
end

function Base.axes(r::CartesianProductUnitRange)
return (CartesianProductUnitRange(r.product, only(axes(r.range))),)
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
end

function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair)
return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i))
end

using Base.Broadcast: DefaultArrayStyle
Expand All @@ -66,12 +100,12 @@ for f in (:+, :-)
function Broadcast.broadcasted(
::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer
)
return CartesianProductUnitRange(r.product, $f.(r.range, x))
return CartesianProductUnitRange(cartesianproduct(r), $f.(unproduct(r), x))
end
function Broadcast.broadcasted(
::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange
)
return CartesianProductUnitRange(r.product, $f.(x, r.range))
return CartesianProductUnitRange(cartesianproduct(r), $f.(x, unproduct(r)))
end
end
end
Expand Down
2 changes: 2 additions & 0 deletions src/fillarrays/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}

_getindex(a::Eye, I1::Colon, I2::Colon) = a

# Like `adapt` but preserves `Eye`.
_adapt(to, a::Eye) = a

Expand Down
12 changes: 6 additions & 6 deletions src/fillarrays/matrixalgebrakit_truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,23 @@ function MatrixAlgebraKit.findtruncated(
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
)
I = findtruncated(Vector(values), strategy.strategy)
prods = collect(only(axes(values)).product)[I]
I_data = unique(map(x -> x.a, prods))
prods = collect(cartesianproduct(only(axes(values))))[I]
I_data = unique(map(arg1, prods))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> x.a == i, prods) == length(values.a)
return count(x -> arg1(x) == i, prods) == length(arg1(values))
end
return (:) × I_data
end
function MatrixAlgebraKit.findtruncated(
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
)
I = findtruncated(Vector(values), strategy.strategy)
prods = collect(only(axes(values)).product)[I]
I_data = unique(map(x -> x.b, prods))
prods = collect(cartesianproduct(only(axes(values))))[I]
I_data = unique(map(x -> arg2(x), prods))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> x.b == i, prods) == length(values.b)
return count(x -> arg2(x) == i, prods) == length(arg2(values))
end
return I_data × (:)
end
Expand Down
Loading
Loading