Skip to content

Commit fd48bed

Browse files
authored
Introduce CartesianPair (#25)
1 parent 1cd2ce3 commit fd48bed

File tree

8 files changed

+132
-84
lines changed

8 files changed

+132
-84
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.18"
4+
version = "0.1.19"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
2323
[compat]
2424
Adapt = "4.3.0"
2525
BlockArrays = "1.6"
26-
BlockSparseArrays = "0.7.19"
26+
BlockSparseArrays = "0.7.20"
2727
DerivableInterfaces = "0.5.0"
2828
DiagonalArrays = "0.3.5"
2929
FillArrays = "1.13.0"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
module KroneckerArraysBlockSparseArraysExt
22

3+
using BlockArrays: Block
4+
using BlockSparseArrays: BlockIndexVector, GenericBlockIndex
5+
using KroneckerArrays: CartesianPair, CartesianProduct
6+
function Base.getindex(b::Block, I1::CartesianPair, Irest::CartesianPair...)
7+
return GenericBlockIndex(b, (I1, Irest...))
8+
end
9+
function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...)
10+
return BlockIndexVector(b, (I1, Irest...))
11+
end
12+
313
using BlockSparseArrays: BlockSparseArrays, blockrange
4-
using KroneckerArrays: CartesianProduct, cartesianrange
14+
using KroneckerArrays: CartesianPair, CartesianProduct, cartesianrange
15+
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair})
16+
return blockrange(map(cartesianrange, bs))
17+
end
518
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
619
return blockrange(map(cartesianrange, bs))
720
end

src/cartesianproduct.jl

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
1-
struct CartesianProduct{A,B}
1+
struct CartesianPair{A,B}
2+
a::A
3+
b::B
4+
end
5+
arguments(a::CartesianPair) = (a.a, a.b)
6+
arguments(a::CartesianPair, n::Int) = arguments(a)[n]
7+
8+
arg1(a::CartesianPair) = a.a
9+
arg2(a::CartesianPair) = a.b
10+
11+
×(a, b) = CartesianPair(a, b)
12+
13+
function Base.show(io::IO, a::CartesianPair)
14+
print(io, a.a, " × ", a.b)
15+
return nothing
16+
end
17+
18+
struct CartesianProduct{TA,TB,A<:AbstractVector{TA},B<:AbstractVector{TB}} <:
19+
AbstractVector{CartesianPair{TA,TB}}
220
a::A
321
b::B
422
end
@@ -13,15 +31,19 @@ function Base.show(io::IO, a::CartesianProduct)
1331
return nothing
1432
end
1533

16-
×(a, b) = CartesianProduct(a, b)
34+
×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
1735
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
18-
Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b]
36+
Base.size(a::CartesianProduct) = (length(a),)
1937

20-
function Base.iterate(a::CartesianProduct, state...)
21-
x = iterate(Iterators.product(a.a, a.b), state...)
22-
isnothing(x) && return x
23-
next, new_state = x
24-
return ×(next...), new_state
38+
function Base.getindex(a::CartesianProduct, i::CartesianProduct)
39+
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
40+
end
41+
function Base.getindex(a::CartesianProduct, i::CartesianPair)
42+
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
43+
end
44+
function Base.getindex(a::CartesianProduct, i::Int)
45+
I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i])
46+
return a[I[1] × I[2]]
2547
end
2648

2749
struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <:
@@ -44,20 +66,32 @@ end
4466
function CartesianProductUnitRange(a, b)
4567
return CartesianProductUnitRange(a × b)
4668
end
47-
to_range(a::AbstractUnitRange) = a
48-
to_range(i::Integer) = Base.OneTo(i)
49-
cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b))
69+
to_product_indices(a::AbstractUnitRange) = a
70+
to_product_indices(i::Integer) = Base.OneTo(i)
71+
cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b))
72+
function cartesianrange(p::CartesianPair)
73+
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
74+
return cartesianrange(p′)
75+
end
5076
function cartesianrange(p::CartesianProduct)
51-
p′ = to_range(p.a) × to_range(p.b)
77+
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
5278
return cartesianrange(p′, Base.OneTo(length(p′)))
5379
end
80+
function cartesianrange(p::CartesianPair, range::AbstractUnitRange)
81+
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
82+
return cartesianrange(p′, range)
83+
end
5484
function cartesianrange(p::CartesianProduct, range::AbstractUnitRange)
55-
p′ = to_range(p.a) × to_range(p.b)
85+
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
5686
return CartesianProductUnitRange(p′, range)
5787
end
5888

5989
function Base.axes(r::CartesianProductUnitRange)
60-
return (CartesianProductUnitRange(r.product, only(axes(r.range))),)
90+
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
91+
end
92+
93+
function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair)
94+
return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i))
6195
end
6296

6397
using Base.Broadcast: DefaultArrayStyle
@@ -66,12 +100,12 @@ for f in (:+, :-)
66100
function Broadcast.broadcasted(
67101
::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer
68102
)
69-
return CartesianProductUnitRange(r.product, $f.(r.range, x))
103+
return CartesianProductUnitRange(cartesianproduct(r), $f.(unproduct(r), x))
70104
end
71105
function Broadcast.broadcasted(
72106
::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange
73107
)
74-
return CartesianProductUnitRange(r.product, $f.(x, r.range))
108+
return CartesianProductUnitRange(cartesianproduct(r), $f.(x, unproduct(r)))
75109
end
76110
end
77111
end

src/fillarrays/kroneckerarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
2121
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
2222
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
2323

24+
_getindex(a::Eye, I1::Colon, I2::Colon) = a
25+
2426
# Like `adapt` but preserves `Eye`.
2527
_adapt(to, a::Eye) = a
2628

src/fillarrays/matrixalgebrakit_truncate.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,23 @@ function MatrixAlgebraKit.findtruncated(
2424
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
2525
)
2626
I = findtruncated(Vector(values), strategy.strategy)
27-
prods = collect(only(axes(values)).product)[I]
28-
I_data = unique(map(x -> x.a, prods))
27+
prods = collect(cartesianproduct(only(axes(values))))[I]
28+
I_data = unique(map(arg1, prods))
2929
# Drop truncations that occur within the identity.
3030
I_data = filter(I_data) do i
31-
return count(x -> x.a == i, prods) == length(values.a)
31+
return count(x -> arg1(x) == i, prods) == length(arg1(values))
3232
end
3333
return (:) × I_data
3434
end
3535
function MatrixAlgebraKit.findtruncated(
3636
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
3737
)
3838
I = findtruncated(Vector(values), strategy.strategy)
39-
prods = collect(only(axes(values)).product)[I]
40-
I_data = unique(map(x -> x.b, prods))
39+
prods = collect(cartesianproduct(only(axes(values))))[I]
40+
I_data = unique(map(x -> arg2(x), prods))
4141
# Drop truncations that occur within the identity.
4242
I_data = filter(I_data) do i
43-
return count(x -> x.b == i, prods) == length(values.b)
43+
return count(x -> arg2(x) == i, prods) == length(arg2(values))
4444
end
4545
return I_data × (:)
4646
end

0 commit comments

Comments
 (0)