11module BlockSparseArraysTensorAlgebraExt
22using BlockArrays: AbstractBlockedUnitRange
33using GradedUnitRanges: tensor_product
4- using TensorAlgebra: TensorAlgebra, FusionStyle , BlockReshapeFusion
4+ using TensorAlgebra: TensorAlgebra, BlockedTuple , BlockReshapeFusion, FusionStyle
55
66function TensorAlgebra.:⊗ (a1:: AbstractBlockedUnitRange , a2:: AbstractBlockedUnitRange )
77 return tensor_product (a1, a2)
@@ -19,10 +19,8 @@ function TensorAlgebra.fusedims(
1919 return blockreshape (a, axes)
2020end
2121
22- function TensorAlgebra. splitdims (
23- :: BlockReshapeFusion , a:: AbstractArray , axes:: AbstractUnitRange...
24- )
25- return blockreshape (a, axes)
22+ function TensorAlgebra. splitdims (:: BlockReshapeFusion , a:: AbstractArray , bt:: BlockedTuple )
23+ return blockreshape (a, Tuple (bt)... )
2624end
2725
2826using BlockArrays:
@@ -58,7 +56,6 @@ using TensorAlgebra:
5856 TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
5957
6058# TODO : Make a `ReduceWhile` library.
61- include (" reducewhile.jl" )
6259
6360TensorAlgebra. FusionStyle (:: AbstractGradedUnitRange ) = SectorFusion ()
6461
@@ -85,15 +82,10 @@ function TensorAlgebra.fusedims(
8582 return block_mergesort (a_reshaped)
8683end
8784
88- function TensorAlgebra. splitdims (
89- :: SectorFusion , a:: AbstractArray , split_axes:: AbstractUnitRange...
90- )
85+ function TensorAlgebra. splitdims (:: SectorFusion , a:: AbstractArray , split_axes:: BlockedTuple )
9186 # First, fuse axes to get `blockmergesortperm`.
9287 # Then unpermute the blocks.
93- axes_prod =
94- groupreducewhile (tensor_product, split_axes, ndims (a); init= OneToOne ()) do i, axis
95- return length (axis) ≤ length (axes (a, i))
96- end
88+ axes_prod = map (t -> tensor_product (t... ), blocks (split_axes))
9789 blockperms = blocksortperm .(axes_prod)
9890 sorted_axes = map ((r, I) -> only (axes (r[I])), axes_prod, blockperms)
9991
@@ -103,7 +95,7 @@ function TensorAlgebra.splitdims(
10395 # for this combination of slicing.
10496 a_unblocked = a[sorted_axes... ]
10597 a_blockpermed = a_unblocked[invblockperm .(blockperms)... ]
106- return splitdims (BlockReshapeFusion (), a_blockpermed, split_axes... )
98+ return splitdims (BlockReshapeFusion (), a_blockpermed, split_axes)
10799end
108100
109101# This is a temporary fix for `eachindex` being broken for BlockSparseArrays
0 commit comments