Skip to content

Commit 33c6fc4

Browse files
committed
quickfix
1 parent 632a793 commit 33c6fc4

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module BlockSparseArraysTensorAlgebraExt
22
using BlockArrays: AbstractBlockedUnitRange
3-
using GradedUnitRanges: tensor_product
3+
using GradedUnitRanges: tensor_product, gradedrange
44
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
55

66
function TensorAlgebra.:(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
@@ -94,13 +94,17 @@ function TensorAlgebra.splitdims(
9494
groupreducewhile(tensor_product, split_axes, ndims(a); init=OneToOne()) do i, axis
9595
return length(axis) length(axes(a, i))
9696
end
97-
blockperms = invblockperm.(blocksortperm.(axes_prod))
97+
blockperms = blocksortperm.(axes_prod)
98+
sorted_axes = ntuple(
99+
i -> gradedrange(map(b -> length(axes_prod[i][b]), blockperms[i])), ndims(a)
100+
)
101+
98102
# TODO: This is doing extra copies of the blocks,
99103
# use `@view a[axes_prod...]` instead.
100104
# That will require implementing some reindexing logic
101105
# for this combination of slicing.
102-
a_unblocked = a[axes_prod...]
103-
a_blockpermed = a_unblocked[blockperms...]
106+
a_unblocked = a[sorted_axes...]
107+
a_blockpermed = a_unblocked[invblockperm.(blockperms)...]
104108
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
105109
end
106110

0 commit comments

Comments
 (0)