Skip to content

Fix splitdims(::BlockSparseArrays) #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 18, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module BlockSparseArraysTensorAlgebraExt
using BlockArrays: AbstractBlockedUnitRange
using GradedUnitRanges: tensor_product
using GradedUnitRanges: tensor_product, gradedrange
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion

function TensorAlgebra.:⊗(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
Expand Down Expand Up @@ -94,13 +94,17 @@
groupreducewhile(tensor_product, split_axes, ndims(a); init=OneToOne()) do i, axis
return length(axis) ≤ length(axes(a, i))
end
blockperms = invblockperm.(blocksortperm.(axes_prod))
blockperms = blocksortperm.(axes_prod)
sorted_axes = ntuple(
i -> gradedrange(map(b -> length(axes_prod[i][b]), blockperms[i])), ndims(a)

Check warning on line 99 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L97-L99

Added lines #L97 - L99 were not covered by tests
)

# TODO: This is doing extra copies of the blocks,
# use `@view a[axes_prod...]` instead.
# That will require implementing some reindexing logic
# for this combination of slicing.
a_unblocked = a[axes_prod...]
a_blockpermed = a_unblocked[blockperms...]
a_unblocked = a[sorted_axes...]
a_blockpermed = a_unblocked[invblockperm.(blockperms)...]

Check warning on line 107 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L106-L107

Added lines #L106 - L107 were not covered by tests
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
end

Expand Down
Loading