@@ -8,14 +8,9 @@ using .BaseExtensions: _permutedims, _permutedims!
88# ===================================== FusionStyle ======================================
99abstract type FusionStyle end
1010
11- struct ReshapeFusion <: FusionStyle end
12-
1311FusionStyle (x) = FusionStyle (typeof (x))
1412FusionStyle (T:: Type ) = throw (MethodError (FusionStyle, (T,)))
1513
16- # Defaults to ReshapeFusion, a simple reshape
17- FusionStyle (:: Type{<:AbstractArray} ) = ReshapeFusion ()
18-
1914# ======================================= misc ========================================
2015trivial_axis (:: Tuple{} ) = Base. OneTo (1 )
2116trivial_axis (:: Tuple{Vararg{AbstractUnitRange}} ) = Base. OneTo (1 )
@@ -135,11 +130,6 @@ function matricize(
135130 return matricize (style, a, blocks (biperm_dest)... )
136131end
137132
138- # default is reshape
139- function matricize (:: ReshapeFusion , a:: AbstractArray , length1:: Val , length2:: Val )
140- return reshape (a, fuseaxes (axes (a), length1, length2)... )
141- end
142-
143133# ==================================== unmatricize =======================================
144134function unmatricize (
145135 m:: AbstractMatrix ,
@@ -164,15 +154,6 @@ function unmatricize(
164154 )
165155end
166156
167- # Implementation using reshape.
168- function unmatricize (
169- style:: ReshapeFusion , m:: AbstractMatrix ,
170- codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
171- domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
172- )
173- return reshape (m, (codomain_axes... , domain_axes... ))
174- end
175-
176157function unmatricize (
177158 m:: AbstractMatrix ,
178159 blocked_axes:: BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}} ,
@@ -280,3 +261,17 @@ function unmatricizeadd!(
280261 style, a_dest, m, blocks (invbiperm)... , α, β
281262 )
282263end
264+
265+ # Defaults to ReshapeFusion, a simple reshape
266+ struct ReshapeFusion <: FusionStyle end
267+ FusionStyle (:: Type{<:AbstractArray} ) = ReshapeFusion ()
268+ function matricize (style:: ReshapeFusion , a:: AbstractArray , length1:: Val , length2:: Val )
269+ return reshape (a, fuseaxes (axes (a), length1, length2))
270+ end
271+ function unmatricize (
272+ style:: ReshapeFusion , m:: AbstractMatrix ,
273+ codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
274+ domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
275+ )
276+ return reshape (m, (codomain_axes... , domain_axes... ))
277+ end
0 commit comments