@@ -8,13 +8,27 @@ using TensorAlgebra: TensorAlgebra, BlockReshapeFusion, BlockedTuple, matricize,
88 matricize_axes, tensor_product_axis, unmatricize
99
1010function TensorAlgebra. tensor_product_axis (
11- :: BlockReshapeFusion , r1:: BlockUnitRange , r2:: BlockUnitRange
11+ style:: BlockReshapeFusion , side:: Val{:codomain} , r1:: BlockUnitRange , r2:: BlockUnitRange
12+ )
13+ return tensor_product_blockrange (style, side, r1, r2)
14+ end
15+ function TensorAlgebra. tensor_product_axis (
16+ style:: BlockReshapeFusion , side:: Val{:domain} , r1:: BlockUnitRange , r2:: BlockUnitRange
17+ )
18+ return tensor_product_blockrange (style, side, r1, r2)
19+ end
20+ function tensor_product_blockrange (
21+ :: BlockReshapeFusion , side:: Val , r1:: BlockUnitRange , r2:: BlockUnitRange
1222 )
1323 (isone (first (r1)) && isone (first (r2))) ||
1424 throw (ArgumentError (" Only one-based axes are supported" ))
1525 blockaxpairs = Iterators. product (eachblockaxes1 (r1), eachblockaxes1 (r2))
16- blockaxs = vec (splat (tensor_product_axis).(blockaxpairs))
17- return blockrange (blockaxs)
26+ blockaxs = map (blockaxpairs) do (b1, b2)
27+ # TODO : Store a FusionStyle for the blocks in `BlockReshapeFusion`
28+ # and use that here.
29+ return tensor_product_axis (side, b1, b2)
30+ end
31+ return blockrange (vec (blockaxs))
1832end
1933
2034function TensorAlgebra. matricize (
0 commit comments