|
1 | 1 | module BlockSparseArraysTensorAlgebraExt |
2 | 2 | using BlockArrays: AbstractBlockedUnitRange |
3 | | -using GradedUnitRanges: tensor_product |
| 3 | +using GradedUnitRanges: tensor_product, gradedrange |
4 | 4 | using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion |
5 | 5 |
|
6 | 6 | function TensorAlgebra.:⊗(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) |
@@ -94,13 +94,17 @@ function TensorAlgebra.splitdims( |
94 | 94 | groupreducewhile(tensor_product, split_axes, ndims(a); init=OneToOne()) do i, axis |
95 | 95 | return length(axis) ≤ length(axes(a, i)) |
96 | 96 | end |
97 | | - blockperms = invblockperm.(blocksortperm.(axes_prod)) |
| 97 | + blockperms = blocksortperm.(axes_prod) |
| 98 | + sorted_axes = ntuple( |
| 99 | + i -> gradedrange(map(b -> length(axes_prod[i][b]), blockperms[i])), ndims(a) |
| 100 | + ) |
| 101 | + |
98 | 102 | # TODO: This is doing extra copies of the blocks, |
99 | 103 | # use `@view a[axes_prod...]` instead. |
100 | 104 | # That will require implementing some reindexing logic |
101 | 105 | # for this combination of slicing. |
102 | | - a_unblocked = a[axes_prod...] |
103 | | - a_blockpermed = a_unblocked[blockperms...] |
| 106 | + a_unblocked = a[sorted_axes...] |
| 107 | + a_blockpermed = a_unblocked[invblockperm.(blockperms)...] |
104 | 108 | return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...) |
105 | 109 | end |
106 | 110 |
|
|
0 commit comments