Skip to content

Commit 09491bd

Browse files
committed
Introduce CartesianPair
1 parent 1cd2ce3 commit 09491bd

File tree

4 files changed

+57
-11
lines changed

4 files changed

+57
-11
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.18"
4+
version = "0.1.19"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
2323
[compat]
2424
Adapt = "4.3.0"
2525
BlockArrays = "1.6"
26-
BlockSparseArrays = "0.7.19"
26+
BlockSparseArrays = "0.7.20"
2727
DerivableInterfaces = "0.5.0"
2828
DiagonalArrays = "0.3.5"
2929
FillArrays = "1.13.0"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
module KroneckerArraysBlockSparseArraysExt
22

3+
using BlockArrays: Block
4+
using BlockSparseArrays: BlockIndexVector, GenericBlockIndex
5+
using KroneckerArrays: CartesianPair, CartesianProduct
6+
function Base.getindex(b::Block, I1::CartesianPair, Irest::CartesianPair...)
7+
return GenericBlockIndex(b, (I1, Irest...))
8+
end
9+
function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...)
10+
return BlockIndexVector(b, (I1, Irest...))
11+
end
12+
313
using BlockSparseArrays: BlockSparseArrays, blockrange
4-
using KroneckerArrays: CartesianProduct, cartesianrange
14+
using KroneckerArrays: CartesianPair, CartesianProduct, cartesianrange
15+
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair})
16+
return blockrange(map(cartesianrange, bs))
17+
end
518
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
619
return blockrange(map(cartesianrange, bs))
720
end

src/cartesianproduct.jl

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
1-
struct CartesianProduct{A,B}
1+
struct CartesianPair{A,B}
2+
a::A
3+
b::B
4+
end
5+
arguments(a::CartesianPair) = (a.a, a.b)
6+
arguments(a::CartesianPair, n::Int) = arguments(a)[n]
7+
8+
arg1(a::CartesianPair) = a.a
9+
arg2(a::CartesianPair) = a.b
10+
11+
×(a, b) = CartesianPair(a, b)
12+
13+
function Base.show(io::IO, a::CartesianPair)
14+
print(io, a.a, " × ", a.b)
15+
return nothing
16+
end
17+
18+
struct CartesianProduct{TA,TB,A<:AbstractVector{TA},B<:AbstractVector{TB}} <:
19+
AbstractVector{CartesianPair{TA,TB}}
220
a::A
321
b::B
422
end
@@ -13,15 +31,19 @@ function Base.show(io::IO, a::CartesianProduct)
1331
return nothing
1432
end
1533

16-
×(a, b) = CartesianProduct(a, b)
34+
×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
1735
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
18-
Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b]
36+
Base.size(a::CartesianProduct) = (length(a),)
1937

20-
function Base.iterate(a::CartesianProduct, state...)
21-
x = iterate(Iterators.product(a.a, a.b), state...)
22-
isnothing(x) && return x
23-
next, new_state = x
24-
return ×(next...), new_state
38+
function Base.getindex(a::CartesianProduct, i::CartesianProduct)
39+
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
40+
end
41+
function Base.getindex(a::CartesianProduct, i::CartesianPair)
42+
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
43+
end
44+
function Base.getindex(a::CartesianProduct, i::Int)
45+
I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i])
46+
return a[I[1] × I[2]]
2547
end
2648

2749
struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <:
@@ -47,10 +69,18 @@ end
4769
to_range(a::AbstractUnitRange) = a
4870
to_range(i::Integer) = Base.OneTo(i)
4971
cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b))
72+
function cartesianrange(p::CartesianPair)
73+
p′ = to_range(p.a) × to_range(p.b)
74+
return cartesianrange(p′)
75+
end
5076
function cartesianrange(p::CartesianProduct)
5177
p′ = to_range(p.a) × to_range(p.b)
5278
return cartesianrange(p′, Base.OneTo(length(p′)))
5379
end
80+
function cartesianrange(p::CartesianPair, range::AbstractUnitRange)
81+
p′ = to_range(p.a) × to_range(p.b)
82+
return cartesianrange(p′, range)
83+
end
5484
function cartesianrange(p::CartesianProduct, range::AbstractUnitRange)
5585
p′ = to_range(p.a) × to_range(p.b)
5686
return CartesianProductUnitRange(p′, range)

src/kroneckerarray.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ end
188188
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
189189
return a.a[map(Base.Fix2(getfield, :a), I)...] a.b[map(Base.Fix2(getfield, :b), I)...]
190190
end
191+
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N}
192+
return a.a[map(Base.Fix2(getfield, :a), I)...] a.b[map(Base.Fix2(getfield, :b), I)...]
193+
end
191194
# Fix ambigiuity error.
192195
Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[]
193196

0 commit comments

Comments
 (0)