Skip to content

Commit e01f28a

Browse files
committed
use BlockedTuple
1 parent e1219b1 commit e01f28a

File tree

1 file changed

+6
-14
lines changed

1 file changed

+6
-14
lines changed

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module BlockSparseArraysTensorAlgebraExt
22
using BlockArrays: AbstractBlockedUnitRange
33
using GradedUnitRanges: tensor_product
4-
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
4+
using TensorAlgebra: TensorAlgebra, BlockedTuple, BlockReshapeFusion, FusionStyle
55

66
function TensorAlgebra.:(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
77
return tensor_product(a1, a2)
@@ -19,10 +19,8 @@ function TensorAlgebra.fusedims(
1919
return blockreshape(a, axes)
2020
end
2121

22-
function TensorAlgebra.splitdims(
23-
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
24-
)
25-
return blockreshape(a, axes)
22+
function TensorAlgebra.splitdims(::BlockReshapeFusion, a::AbstractArray, bt::BlockedTuple)
23+
return blockreshape(a, Tuple(bt)...)
2624
end
2725

2826
using BlockArrays:
@@ -58,7 +56,6 @@ using TensorAlgebra:
5856
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
5957

6058
# TODO: Make a `ReduceWhile` library.
61-
include("reducewhile.jl")
6259

6360
TensorAlgebra.FusionStyle(::AbstractGradedUnitRange) = SectorFusion()
6461

@@ -85,15 +82,10 @@ function TensorAlgebra.fusedims(
8582
return block_mergesort(a_reshaped)
8683
end
8784

88-
function TensorAlgebra.splitdims(
89-
::SectorFusion, a::AbstractArray, split_axes::AbstractUnitRange...
90-
)
85+
function TensorAlgebra.splitdims(::SectorFusion, a::AbstractArray, split_axes::BlockedTuple)
9186
# First, fuse axes to get `blockmergesortperm`.
9287
# Then unpermute the blocks.
93-
axes_prod =
94-
groupreducewhile(tensor_product, split_axes, ndims(a); init=OneToOne()) do i, axis
95-
return length(axis) length(axes(a, i))
96-
end
88+
axes_prod = map(t -> tensor_product(t...), blocks(split_axes))
9789
blockperms = blocksortperm.(axes_prod)
9890
sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms)
9991

@@ -103,7 +95,7 @@ function TensorAlgebra.splitdims(
10395
# for this combination of slicing.
10496
a_unblocked = a[sorted_axes...]
10597
a_blockpermed = a_unblocked[invblockperm.(blockperms)...]
106-
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
98+
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes)
10799
end
108100

109101
# This is a temporary fix for `eachindex` being broken for BlockSparseArrays

0 commit comments

Comments
 (0)