1
1
module BlockSparseArraysTensorAlgebraExt
2
2
using BlockArrays: AbstractBlockedUnitRange
3
3
using GradedUnitRanges: tensor_product
4
- using TensorAlgebra: TensorAlgebra, FusionStyle , BlockReshapeFusion
4
+ using TensorAlgebra: TensorAlgebra, BlockedTuple , BlockReshapeFusion, FusionStyle
5
5
6
6
function TensorAlgebra.:⊗ (a1:: AbstractBlockedUnitRange , a2:: AbstractBlockedUnitRange )
7
7
return tensor_product (a1, a2)
@@ -19,10 +19,8 @@ function TensorAlgebra.fusedims(
19
19
return blockreshape (a, axes)
20
20
end
21
21
22
- function TensorAlgebra. splitdims (
23
- :: BlockReshapeFusion , a:: AbstractArray , axes:: AbstractUnitRange...
24
- )
25
- return blockreshape (a, axes)
22
+ function TensorAlgebra. splitdims (:: BlockReshapeFusion , a:: AbstractArray , bt:: BlockedTuple )
23
+ return blockreshape (a, Tuple (bt)... )
26
24
end
27
25
28
26
using BlockArrays:
@@ -58,7 +56,6 @@ using TensorAlgebra:
58
56
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
59
57
60
58
# TODO : Make a `ReduceWhile` library.
61
- include (" reducewhile.jl" )
62
59
63
60
TensorAlgebra. FusionStyle (:: AbstractGradedUnitRange ) = SectorFusion ()
64
61
@@ -85,15 +82,10 @@ function TensorAlgebra.fusedims(
85
82
return block_mergesort (a_reshaped)
86
83
end
87
84
88
- function TensorAlgebra. splitdims (
89
- :: SectorFusion , a:: AbstractArray , split_axes:: AbstractUnitRange...
90
- )
85
+ function TensorAlgebra. splitdims (:: SectorFusion , a:: AbstractArray , split_axes:: BlockedTuple )
91
86
# First, fuse axes to get `blockmergesortperm`.
92
87
# Then unpermute the blocks.
93
- axes_prod =
94
- groupreducewhile (tensor_product, split_axes, ndims (a); init= OneToOne ()) do i, axis
95
- return length (axis) ≤ length (axes (a, i))
96
- end
88
+ axes_prod = map (t -> tensor_product (t... ), blocks (split_axes))
97
89
blockperms = blocksortperm .(axes_prod)
98
90
sorted_axes = map ((r, I) -> only (axes (r[I])), axes_prod, blockperms)
99
91
@@ -103,7 +95,7 @@ function TensorAlgebra.splitdims(
103
95
# for this combination of slicing.
104
96
a_unblocked = a[sorted_axes... ]
105
97
a_blockpermed = a_unblocked[invblockperm .(blockperms)... ]
106
- return splitdims (BlockReshapeFusion (), a_blockpermed, split_axes... )
98
+ return splitdims (BlockReshapeFusion (), a_blockpermed, split_axes)
107
99
end
108
100
109
101
# This is a temporary fix for `eachindex` being broken for BlockSparseArrays
0 commit comments