1- struct CartesianProduct{A,B}
1+ struct CartesianPair{A,B}
2+ a:: A
3+ b:: B
4+ end
5+ arguments (a:: CartesianPair ) = (a. a, a. b)
6+ arguments (a:: CartesianPair , n:: Int ) = arguments (a)[n]
7+
8+ arg1 (a:: CartesianPair ) = a. a
9+ arg2 (a:: CartesianPair ) = a. b
10+
11+ × (a, b) = CartesianPair (a, b)
12+
13+ function Base. show (io:: IO , a:: CartesianPair )
14+ print (io, a. a, " × " , a. b)
15+ return nothing
16+ end
17+
18+ struct CartesianProduct{TA,TB,A<: AbstractVector{TA} ,B<: AbstractVector{TB} } < :
19+ AbstractVector{CartesianPair{TA,TB}}
220 a:: A
321 b:: B
422end
@@ -13,17 +31,44 @@ function Base.show(io::IO, a::CartesianProduct)
1331 return nothing
1432end
1533
16- × (a, b) = CartesianProduct (a, b)
17- Base. length (a:: CartesianProduct ) = length (a. a) * length (a. b)
18- Base. getindex (a:: CartesianProduct , i:: CartesianProduct ) = a. a[i. a] × a. b[i. b]
34+ # This is used when printing block sparse arrays with KroneckerArray
35+ # blocks.
36+ # TODO : Investigate if this is needed or if it can be avoided
37+ # by iterating over CartesianProduct axes.
38+ function Base. checkindex (:: Type{Bool} , inds:: CartesianProduct , i:: Int )
39+ return checkindex (Bool, Base. OneTo (length (inds)), i)
40+ end
41+
42+ × (a:: AbstractVector , b:: AbstractVector ) = CartesianProduct (a, b)
43+ Base. length (a:: CartesianProduct ) = length (arg1 (a)) * length (arg2 (a))
44+ Base. size (a:: CartesianProduct ) = (length (a),)
45+ function Base. getindex (a:: CartesianProduct , i:: CartesianProduct )
46+ return arg1 (a)[arg1 (i)] × arg2 (a)[arg2 (i)]
47+ end
48+ function Base. getindex (a:: CartesianProduct , i:: CartesianPair )
49+ return arg1 (a)[arg1 (i)] × arg2 (a)[arg2 (i)]
50+ end
51+ function Base. getindex (a:: CartesianProduct , i:: Int )
52+ I = Tuple (CartesianIndices ((length (arg1 (a)), length (arg2 (a))))[i])
53+ return a[I[1 ] × I[2 ]]
54+ end
55+
56+ using Base: promote_shape
57+ function Base. promote_shape (
58+ a:: Tuple{Vararg{CartesianProduct}} , b:: Tuple{Vararg{CartesianProduct}}
59+ )
60+ return promote_shape (arg1 .(a), arg1 .(b)) × promote_shape (arg2 .(a), arg2 .(b))
61+ end
1962
20- function Base. iterate (a:: CartesianProduct , state... )
21- x = iterate (Iterators. product (a. a, a. b), state... )
22- isnothing (x) && return x
23- next, new_state = x
24- return × (next... ), new_state
63+ using Base. Broadcast: axistype
64+ function Base. Broadcast. axistype (r1:: CartesianProduct , r2:: CartesianProduct )
65+ return axistype (arg1 (r1), arg1 (r2)) × axistype (arg2 (r1), arg2 (r2))
2566end
2667
68+ # # function Base.to_index(A::KroneckerArray, I::CartesianProduct)
69+ # # return I
70+ # # end
71+
2772struct CartesianProductUnitRange{T,P<: CartesianProduct ,R<: AbstractUnitRange{T} } < :
2873 AbstractUnitRange{T}
2974 product:: P
@@ -38,27 +83,36 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
3883arg1 (a:: CartesianProductUnitRange ) = arg1 (cartesianproduct (a))
3984arg2 (a:: CartesianProductUnitRange ) = arg2 (cartesianproduct (a))
4085
86+ function Base. show (io:: IO , r:: CartesianProductUnitRange )
87+ print (io, cartesianproduct (r), " : " , unproduct (r))
88+ return nothing
89+ end
90+ function Base. show (io:: IO , mime:: MIME"text/plain" , r:: CartesianProductUnitRange )
91+ show (io, mime, cartesianproduct (r))
92+ println (io)
93+ show (io, mime, unproduct (r))
94+ return nothing
95+ end
96+
4197function CartesianProductUnitRange (p:: CartesianProduct )
4298 return CartesianProductUnitRange (p, Base. OneTo (length (p)))
4399end
44100function CartesianProductUnitRange (a, b)
45101 return CartesianProductUnitRange (a × b)
46102end
47- to_range (a:: AbstractUnitRange ) = a
48- to_range (i:: Integer ) = Base. OneTo (i)
49- cartesianrange (a, b) = cartesianrange (to_range (a) × to_range (b))
103+ to_product_indices (a:: AbstractVector ) = a
104+ to_product_indices (i:: Integer ) = Base. OneTo (i)
105+ cartesianrange (a, b) = cartesianrange (to_product_indices (a) × to_product_indices (b))
50106function cartesianrange (p:: CartesianProduct )
51- p′ = to_range (p . a) × to_range (p . b )
107+ p′ = to_product_indices ( arg1 (p)) × to_product_indices ( arg2 (p) )
52108 return cartesianrange (p′, Base. OneTo (length (p′)))
53109end
54110function cartesianrange (p:: CartesianProduct , range:: AbstractUnitRange )
55- p′ = to_range (p . a) × to_range (p . b )
111+ p′ = to_product_indices ( arg1 (p)) × to_product_indices ( arg2 (p) )
56112 return CartesianProductUnitRange (p′, range)
57113end
58114
59- function Base. axes (r:: CartesianProductUnitRange )
60- return (CartesianProductUnitRange (r. product, only (axes (r. range))),)
61- end
115+ Base. axes (r:: CartesianProductUnitRange ) = (cartesianrange (cartesianproduct (r)),)
62116
63117using Base. Broadcast: DefaultArrayStyle
64118for f in (:+ , :- )
@@ -84,3 +138,7 @@ function Base.Broadcast.axistype(
84138 range = axistype (unproduct (r1), unproduct (r2))
85139 return cartesianrange (prod, range)
86140end
141+
142+ function Base. checkindex (:: Type{Bool} , inds:: CartesianProductUnitRange , i:: CartesianPair )
143+ return checkindex (Bool, arg1 (inds), arg1 (i)) && checkindex (Bool, arg2 (inds), arg2 (i))
144+ end
0 commit comments