@@ -21,77 +21,146 @@ trivial_axis(::Tuple{}) = Base.OneTo(1)
2121trivial_axis (:: Tuple{Vararg{AbstractUnitRange}} ) = Base. OneTo (1 )
2222trivial_axis (:: Tuple{Vararg{AbstractBlockedUnitRange}} ) = blockedrange ([1 ])
2323
24+ # Inner version takes a list of sub-permutations, overload this one if needed.
2425function fuseaxes (
25- axes:: Tuple{Vararg{AbstractUnitRange}} , blockedperm :: AbstractBlockPermutation
26+ axes:: Tuple{Vararg{AbstractUnitRange}} , lengths :: Val...
2627 )
27- axesblocks = blocks (axes[blockedperm ])
28+ axesblocks = blocks (axes[blockedtrivialperm (lengths) ])
2829 return map (block -> isempty (block) ? trivial_axis (axes) : ⊗ (block... ), axesblocks)
2930end
3031
32+ # Inner version takes a list of sub-permutations, overload this one if needed.
33+ function fuseaxes (
34+ axes:: Tuple{Vararg{AbstractUnitRange}} , permblocks:: Tuple{Vararg{Int}} ...
35+ )
36+ axes′ = map (d -> axes[d], permmortar (permblocks))
37+ return fuseaxes (axes′, Val .(length .(permblocks))... )
38+ end
39+
40+ function fuseaxes (
41+ axes:: Tuple{Vararg{AbstractUnitRange}} , blockedperm:: AbstractBlockPermutation
42+ )
43+ return fuseaxes (axes, blocks (blockedperm)... )
44+ end
45+
46+ # Inner version takes a list of sub-permutations, overload this one if needed.
47+ function permuteblockeddims (a:: AbstractArray , perm1, perm2)
48+ return _permutedims (a, (perm1... , perm2... ))
49+ end
50+ function permuteblockeddims! (a_dest:: AbstractArray , a_src:: AbstractArray , perm1, perm2)
51+ return _permutedims! (a_dest, a_src, (perm1... , perm2... ))
52+ end
53+
3154# TODO remove _permutedims once support for Julia 1.10 is dropped
3255# define permutedims with a BlockedPermuation. Default is to flatten it.
33- function permuteblockeddims (a:: AbstractArray , biperm:: AbstractBlockPermutation )
34- return _permutedims (a, Tuple (biperm))
56+ function permuteblockeddims (a:: AbstractArray , biperm:: AbstractBlockPermutation{2} )
57+ return permuteblockeddims (a, blocks (biperm)... )
3558end
36-
3759function permuteblockeddims! (
38- a :: AbstractArray , b :: AbstractArray , biperm:: AbstractBlockPermutation
60+ a_dest :: AbstractArray , a_src :: AbstractArray , biperm:: AbstractBlockPermutation{2}
3961 )
40- return _permutedims! (a, b, Tuple (biperm))
62+ return permuteblockeddims! (a_dest, a_src, blocks (biperm)... )
4163end
4264
4365# ===================================== matricize ========================================
4466# TBD settle copy/not copy convention
4567# matrix factorizations assume copy
4668# maybe: copy=false kwarg
4769
48- function matricize (a:: AbstractArray , biperm_dest:: AbstractBlockPermutation{2} )
49- ndims (a) == length (biperm_dest) || throw (ArgumentError (" Invalid bipermutation" ))
50- return matricize (FusionStyle (a), a, biperm_dest)
70+ function matricize (a:: AbstractArray , length1:: Val , length2:: Val )
71+ return matricize (FusionStyle (a), a, length1, length2)
72+ end
73+ # This is the primary function that should be overloaded for new fusion styles.
74+ # This assumes the permutation was already performed.
75+ function matricize (style:: FusionStyle , a:: AbstractArray , length1:: Val , length2:: Val )
76+ return throw (
77+ MethodError (
78+ matricize, Tuple{typeof (style), typeof (a), typeof (length1), typeof (length2)}
79+ )
80+ )
5181end
5282
5383function matricize (
54- style :: FusionStyle , a :: AbstractArray , biperm_dest :: AbstractBlockPermutation{2 }
84+ a :: AbstractArray , permblock1 :: Tuple{Vararg{Int}} , permblock2 :: Tuple{Vararg{Int} }
5585 )
56- a_perm = permuteblockeddims (a, biperm_dest)
57- return matricize (style, a_perm, trivialperm (biperm_dest))
86+ return matricize (FusionStyle (a), a, permblock1, permblock2)
5887end
59-
88+ # This is a more advanced version to overload where the permutation is actually performed.
6089function matricize (
61- style:: FusionStyle , a:: AbstractArray , biperm_dest:: BlockedTrivialPermutation{2}
90+ style:: FusionStyle , a:: AbstractArray ,
91+ permblock1:: NTuple{N1, Int} , permblock2:: NTuple{N2, Int}
92+ ) where {N1, N2}
93+ ndims (a) == length (permblock1) + length (permblock2) ||
94+ throw (ArgumentError (" Invalid bipermutation" ))
95+ a_perm = permuteblockeddims (a, permblock1, permblock2)
96+ return matricize (style, a_perm, Val (length (permblock1)), Val (length (permblock2)))
97+ end
98+
99+ # Process inputs such as `EllipsisNotation.Ellipsis`.
100+ function to_permblocks (a:: AbstractArray , permblocks:: NTuple{2, Tuple{Vararg{Int}}} )
101+ isperm ((permblocks[1 ]. .. , permblocks[2 ]. .. )) ||
102+ throw (ArgumentError (" Invalid bipermutation" ))
103+ return permblocks
104+ end
105+ # Like `setcomplement` is like `setdiff` but assumes t2 ⊆ t1.
106+ function tuplesetcomplement (t1:: NTuple{N1} , t2:: NTuple{N2} ) where {N1, N2}
107+ t2 ⊆ t1 || throw (ArgumentError (" t2 must be a subset of t1" ))
108+ return NTuple {N1 - N2} (setdiff (t1, t2))
109+ end
110+ function to_permblocks (
111+ a:: AbstractArray , permblocks:: Tuple{Tuple{Ellipsis}, Tuple{Vararg{Int}}}
112+ )
113+ permblocks1 = tuplesetcomplement (ntuple (identity, ndims (a)), permblocks[2 ])
114+ return (permblocks1, permblocks[2 ])
115+ end
116+ function to_permblocks (
117+ a:: AbstractArray , permblocks:: Tuple{Tuple{Vararg{Int}}, Tuple{Ellipsis}}
62118 )
63- return throw (MethodError (matricize, Tuple{typeof (style), typeof (a), typeof (biperm_dest)}))
119+ permblocks2 = tuplesetcomplement (ntuple (identity, ndims (a)), permblocks[1 ])
120+ return (permblocks[1 ], permblocks2)
121+ end
122+ function matricize (a:: AbstractArray , permblock1, permblock2)
123+ return matricize (FusionStyle (a), a, permblock1, permblock2)
124+ end
125+ function matricize (style:: FusionStyle , a:: AbstractArray , permblock1, permblock2)
126+ return matricize (style, a, to_permblocks (a, (permblock1, permblock2))... )
64127end
65128
66- # default is reshape
129+ function matricize (a:: AbstractArray , biperm_dest:: AbstractBlockPermutation{2} )
130+ return matricize (FusionStyle (a), a, biperm_dest)
131+ end
67132function matricize (
68- :: ReshapeFusion , a:: AbstractArray , biperm_dest:: BlockedTrivialPermutation {2}
133+ style :: FusionStyle , a:: AbstractArray , biperm_dest:: AbstractBlockPermutation {2}
69134 )
70- new_axes = fuseaxes (axes (a), biperm_dest)
71- return reshape (a, new_axes... )
135+ return matricize (style, a, blocks (biperm_dest)... )
72136end
73137
74- function matricize (a:: AbstractArray , permblock1:: Tuple , permblock2:: Tuple )
75- return matricize (a, blockedpermvcat (permblock1, permblock2; length = Val (ndims (a))))
138+ # default is reshape
139+ function matricize (:: ReshapeFusion , a:: AbstractArray , length1:: Val , length2:: Val )
140+ return reshape (a, fuseaxes (axes (a), length1, length2)... )
76141end
77142
78143# ==================================== unmatricize =======================================
79144function unmatricize (m:: AbstractMatrix , axes_dest, invbiperm:: AbstractBlockPermutation{2} )
80- length (axes_dest) == length (invbiperm) ||
81- throw (ArgumentError (" axes do not match permutation" ))
82145 return unmatricize (FusionStyle (m), m, axes_dest, invbiperm)
83146end
84-
85147function unmatricize (
86- :: FusionStyle , m:: AbstractMatrix , axes_dest, invbiperm:: AbstractBlockPermutation{2}
148+ style :: FusionStyle , m:: AbstractMatrix , axes_dest, invbiperm:: AbstractBlockPermutation{2}
87149 )
150+ length (axes_dest) == length (invbiperm) ||
151+ throw (ArgumentError (" axes do not match permutation" ))
88152 blocked_axes = axes_dest[invbiperm]
89- a12 = unmatricize (m, blocked_axes)
153+ a12 = unmatricize (style, m, blocked_axes)
90154 biperm_dest = biperm (invperm (invbiperm), length_codomain (axes_dest))
91-
92155 return permuteblockeddims (a12, biperm_dest)
93156end
94157
158+ function unmatricize (
159+ m:: AbstractMatrix ,
160+ blocked_axes:: BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}} ,
161+ )
162+ return unmatricize (FusionStyle (m), m, blocked_axes)
163+ end
95164function unmatricize (
96165 :: ReshapeFusion ,
97166 m:: AbstractMatrix ,
@@ -100,30 +169,49 @@ function unmatricize(
100169 return reshape (m, Tuple (blocked_axes)... )
101170end
102171
103- function unmatricize (m:: AbstractMatrix , blocked_axes)
104- return unmatricize (FusionStyle (m), m, blocked_axes)
105- end
106-
107172function unmatricize (
108173 m:: AbstractMatrix ,
109174 codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
110175 domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
111176 )
177+ return unmatricize (FusionStyle (m), m, codomain_axes, domain_axes)
178+ end
179+ function unmatricize (
180+ style:: FusionStyle , m:: AbstractMatrix ,
181+ codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
182+ domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
183+ )
112184 blocked_axes = tuplemortar ((codomain_axes, domain_axes))
113- return unmatricize (m, blocked_axes)
185+ return unmatricize (style, m, blocked_axes)
114186end
115187
116- function unmatricize! (a_dest, m:: AbstractMatrix , invbiperm:: AbstractBlockPermutation{2} )
188+ function unmatricize! (
189+ a_dest:: AbstractArray , m:: AbstractMatrix , invbiperm:: AbstractBlockPermutation{2}
190+ )
191+ return unmatricize! (FusionStyle (m), a_dest, m, invbiperm)
192+ end
193+ function unmatricize! (
194+ style:: FusionStyle , a_dest:: AbstractArray , m:: AbstractMatrix ,
195+ invbiperm:: AbstractBlockPermutation{2} ,
196+ )
117197 ndims (a_dest) == length (invbiperm) ||
118198 throw (ArgumentError (" destination does not match permutation" ))
119199 blocked_axes = axes (a_dest)[invbiperm]
120- a_perm = unmatricize (m, blocked_axes)
200+ a_perm = unmatricize (style, m, blocked_axes)
121201 biperm_dest = biperm (invperm (invbiperm), length_codomain (axes (a_dest)))
122202 return permuteblockeddims! (a_dest, a_perm, biperm_dest)
123203end
124204
125- function unmatricizeadd! (a_dest, a_dest_mat, invbiperm, α, β)
126- a12 = unmatricize (a_dest_mat, axes (a_dest), invbiperm)
205+ function unmatricizeadd! (
206+ a_dest:: AbstractArray , m:: AbstractMatrix , invbiperm:: AbstractBlockPermutation{2} ,
207+ α:: Number , β:: Number
208+ )
209+ return unmatricizeadd! (FusionStyle (a_dest), a_dest, m, invbiperm, α, β)
210+ end
211+ function unmatricizeadd! (
212+ style:: FusionStyle , a_dest:: AbstractArray , m:: AbstractMatrix , invbiperm, α, β
213+ )
214+ a12 = unmatricize (style, m, axes (a_dest), invbiperm)
127215 a_dest .= α .* a12 .+ β .* a_dest
128216 return a_dest
129217end
0 commit comments