Skip to content

Commit 042233f

Browse files
committed
Introduce CartesianPair
1 parent 1cd2ce3 commit 042233f

File tree

9 files changed

+188
-92
lines changed

9 files changed

+188
-92
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.18"
4+
version = "0.1.19"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,53 @@
11
module KroneckerArraysBlockSparseArraysExt
22

3-
using BlockSparseArrays: BlockSparseArrays, blockrange
4-
using KroneckerArrays: CartesianProduct, cartesianrange
3+
using BlockArrays: Block
4+
using BlockSparseArrays: BlockIndexVector, GenericBlockIndex
5+
using KroneckerArrays: CartesianPair, CartesianProduct
6+
function Base.getindex(b::Block, I1::CartesianPair, Irest::CartesianPair...)
7+
return GenericBlockIndex(b, (I1, Irest...))
8+
end
9+
function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...)
10+
return BlockIndexVector(b, (I1, Irest...))
11+
end
12+
13+
using BlockSparseArrays: BlockSparseArrays, BlockUnitRange, blockrange
14+
using KroneckerArrays: CartesianPair, CartesianProduct, ×, cartesianrange
15+
516
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
6-
return blockrange(map(cartesianrange, bs))
17+
return blockrange(cartesianrange.(bs))
18+
end
19+
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair{<:Integer,<:Integer}})
20+
bs′ = map(bs) do b
21+
return Base.OneTo(arg1(b)) × Base.OneTo(arg2(b))
22+
end
23+
return blockrange(bs′)
24+
end
25+
26+
using BlockSparseArrays: BlockSparseArrays, infimum
27+
using KroneckerArrays: cartesianproduct, CartesianProductUnitRange
28+
function BlockSparseArrays.infimum(r1::CartesianProductUnitRange, r2::CartesianProductUnitRange)
29+
return cartesianrange(infimum(cartesianproduct.((r1, r2))...))
30+
end
31+
function BlockSparseArrays.infimum(r1::CartesianProduct, r2::CartesianProduct)
32+
return infimum(arg1(r1), arg1(r2)) × infimum(arg2(r1), arg2(r2))
33+
end
34+
35+
using BlockArrays: Block
36+
using KroneckerArrays: cartesianrange
37+
function Base.getindex(
38+
r::BlockUnitRange{<:Integer,<:Vector{<:CartesianProduct}}, I::Block{1,Int64}
39+
)
40+
prod = eachblockaxis(r)[Int(I)]
41+
range = r.r[I]
42+
return cartesianrange(prod, range)
43+
end
44+
45+
# Fix ambiguity error with BlockArrays.jl.
46+
using BlockArrays: AbstractBlockArray
47+
function Base.similar(
48+
a::AbstractBlockArray, axs::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}}
49+
)
50+
return similar(a, eltype(a), axs)
751
end
852

953
using BlockArrays: AbstractBlockedUnitRange

src/cartesianproduct.jl

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
1-
struct CartesianProduct{A,B}
1+
struct CartesianPair{A,B}
2+
a::A
3+
b::B
4+
end
5+
arguments(a::CartesianPair) = (a.a, a.b)
6+
arguments(a::CartesianPair, n::Int) = arguments(a)[n]
7+
8+
arg1(a::CartesianPair) = a.a
9+
arg2(a::CartesianPair) = a.b
10+
11+
×(a, b) = CartesianPair(a, b)
12+
13+
function Base.show(io::IO, a::CartesianPair)
14+
print(io, a.a, " × ", a.b)
15+
return nothing
16+
end
17+
18+
struct CartesianProduct{TA,TB,A<:AbstractVector{TA},B<:AbstractVector{TB}} <:
19+
AbstractVector{CartesianPair{TA,TB}}
220
a::A
321
b::B
422
end
@@ -13,17 +31,44 @@ function Base.show(io::IO, a::CartesianProduct)
1331
return nothing
1432
end
1533

16-
×(a, b) = CartesianProduct(a, b)
17-
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
18-
Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b]
34+
# This is used when printing block sparse arrays with KroneckerArray
35+
# blocks.
36+
# TODO: Investigate if this is needed or if it can be avoided
37+
# by iterating over CartesianProduct axes.
38+
function Base.checkindex(::Type{Bool}, inds::CartesianProduct, i::Int)
39+
return checkindex(Bool, Base.OneTo(length(inds)), i)
40+
end
41+
42+
×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
43+
Base.length(a::CartesianProduct) = length(arg1(a)) * length(arg2(a))
44+
Base.size(a::CartesianProduct) = (length(a),)
45+
function Base.getindex(a::CartesianProduct, i::CartesianProduct)
46+
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
47+
end
48+
function Base.getindex(a::CartesianProduct, i::CartesianPair)
49+
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
50+
end
51+
function Base.getindex(a::CartesianProduct, i::Int)
52+
I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i])
53+
return a[I[1] × I[2]]
54+
end
55+
56+
using Base: promote_shape
57+
function Base.promote_shape(
58+
a::Tuple{Vararg{CartesianProduct}}, b::Tuple{Vararg{CartesianProduct}}
59+
)
60+
return promote_shape(arg1.(a), arg1.(b)) × promote_shape(arg2.(a), arg2.(b))
61+
end
1962

20-
function Base.iterate(a::CartesianProduct, state...)
21-
x = iterate(Iterators.product(a.a, a.b), state...)
22-
isnothing(x) && return x
23-
next, new_state = x
24-
return ×(next...), new_state
63+
using Base.Broadcast: axistype
64+
function Base.Broadcast.axistype(r1::CartesianProduct, r2::CartesianProduct)
65+
return axistype(arg1(r1), arg1(r2)) × axistype(arg2(r1), arg2(r2))
2566
end
2667

68+
## function Base.to_index(A::KroneckerArray, I::CartesianProduct)
69+
## return I
70+
## end
71+
2772
struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <:
2873
AbstractUnitRange{T}
2974
product::P
@@ -38,27 +83,36 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
3883
arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
3984
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))
4085

86+
function Base.show(io::IO, r::CartesianProductUnitRange)
87+
print(io, cartesianproduct(r), ": ", unproduct(r))
88+
return nothing
89+
end
90+
function Base.show(io::IO, mime::MIME"text/plain", r::CartesianProductUnitRange)
91+
show(io, mime, cartesianproduct(r))
92+
println(io)
93+
show(io, mime, unproduct(r))
94+
return nothing
95+
end
96+
4197
function CartesianProductUnitRange(p::CartesianProduct)
4298
return CartesianProductUnitRange(p, Base.OneTo(length(p)))
4399
end
44100
function CartesianProductUnitRange(a, b)
45101
return CartesianProductUnitRange(a × b)
46102
end
47-
to_range(a::AbstractUnitRange) = a
48-
to_range(i::Integer) = Base.OneTo(i)
49-
cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b))
103+
to_product_indices(a::AbstractVector) = a
104+
to_product_indices(i::Integer) = Base.OneTo(i)
105+
cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b))
50106
function cartesianrange(p::CartesianProduct)
51-
p′ = to_range(p.a) × to_range(p.b)
107+
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
52108
return cartesianrange(p′, Base.OneTo(length(p′)))
53109
end
54110
function cartesianrange(p::CartesianProduct, range::AbstractUnitRange)
55-
p′ = to_range(p.a) × to_range(p.b)
111+
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
56112
return CartesianProductUnitRange(p′, range)
57113
end
58114

59-
function Base.axes(r::CartesianProductUnitRange)
60-
return (CartesianProductUnitRange(r.product, only(axes(r.range))),)
61-
end
115+
Base.axes(r::CartesianProductUnitRange) = (cartesianrange(cartesianproduct(r)),)
62116

63117
using Base.Broadcast: DefaultArrayStyle
64118
for f in (:+, :-)
@@ -84,3 +138,7 @@ function Base.Broadcast.axistype(
84138
range = axistype(unproduct(r1), unproduct(r2))
85139
return cartesianrange(prod, range)
86140
end
141+
142+
function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair)
143+
return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i))
144+
end

src/fillarrays/kroneckerarray.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
using FillArrays: FillArrays, Zeros
22
function FillArrays.fillsimilar(
3-
a::Zeros{T},
4-
ax::Tuple{
5-
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
6-
},
3+
a::Zeros{T}, ax::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}}
74
) where {T}
85
return Zeros{T}(arg1.(ax)) Zeros{T}(arg2.(ax))
96
end
107

8+
# Work around that `Zeros` requires `AbstractUnitRange` axes.
9+
function FillArrays.Zeros{T,N}(
10+
ax::Tuple{CartesianProduct,Vararg{CartesianProduct}}
11+
) where {T,N}
12+
return Zeros{T,N}(cartesianslice.(ax))
13+
end
14+
1115
using FillArrays: RectDiagonal, OnesVector
1216
const RectEye{T,V<:OnesVector{T},Axes} = RectDiagonal{T,V,Axes}
1317

@@ -68,6 +72,8 @@ end
6872
# Like `copy` but preserves `Eye`.
6973
_copy(a::Eye) = a
7074

75+
_getindex(a::Eye, I1::Colon, I2::Colon) = a
76+
7177
using DerivableInterfaces: DerivableInterfaces, zero!
7278
function DerivableInterfaces.zero!(a::EyeKronecker)
7379
zero!(a.b)

src/fillarrays/matrixalgebrakit.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
1-
function infimum(r1::AbstractRange, r2::AbstractUnitRange)
1+
function infimum(r1::AbstractUnitRange, r2::AbstractUnitRange)
22
Base.require_one_based_indexing(r1, r2)
33
if length(r1) length(r2)
44
return r1
55
else
66
return r2
77
end
88
end
9-
function supremum(r1::AbstractRange, r2::AbstractUnitRange)
9+
function infimum(r1::CartesianProduct, r2::CartesianProduct)
10+
return infimum(arg1(r1), arg1(r2)) × infimum(arg2(r1), arg2(r2))
11+
end
12+
function supremum(r1::AbstractUnitRange, r2::AbstractUnitRange)
1013
Base.require_one_based_indexing(r1, r2)
1114
if length(r1) length(r2)
1215
return r1
1316
else
1417
return r2
1518
end
1619
end
20+
function supremum(r1::CartesianProduct, r2::CartesianProduct)
21+
return supremum(arg1(r1), arg1(r2)) × supremum(arg2(r1), arg2(r2))
22+
end
1723

1824
# Allow customization for `Eye`.
1925
_diagview(a::Eye) = parent(a)

src/fillarrays/matrixalgebrakit_truncate.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,23 @@ function MatrixAlgebraKit.findtruncated(
2424
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
2525
)
2626
I = findtruncated(Vector(values), strategy.strategy)
27-
prods = collect(only(axes(values)).product)[I]
28-
I_data = unique(map(x -> x.a, prods))
27+
prods = only(axes(values))[I]
28+
I_data = unique(arg1.(prods))
2929
# Drop truncations that occur within the identity.
3030
I_data = filter(I_data) do i
31-
return count(x -> x.a == i, prods) == length(values.a)
31+
return count(x -> arg1(x) == i, prods) == length(arg1(values))
3232
end
3333
return (:) × I_data
3434
end
3535
function MatrixAlgebraKit.findtruncated(
3636
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
3737
)
3838
I = findtruncated(Vector(values), strategy.strategy)
39-
prods = collect(only(axes(values)).product)[I]
40-
I_data = unique(map(x -> x.b, prods))
39+
prods = only(axes(values))[I]
40+
I_data = unique(map(x -> arg2(x), prods))
4141
# Drop truncations that occur within the identity.
4242
I_data = filter(I_data) do i
43-
return count(x -> x.b == i, prods) == length(values.b)
43+
return count(x -> arg2(x) == i, prods) == length(arg2(values))
4444
end
4545
return I_data × (:)
4646
end

0 commit comments

Comments
 (0)