diff --git a/Project.toml b/Project.toml index f10faa8b..534fbf76 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.2.21" +version = "0.2.22" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl index 37f03d09..5e164088 100644 --- a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl +++ b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl @@ -1,7 +1,7 @@ module BlockSparseArraysTensorAlgebraExt using BlockArrays: AbstractBlockedUnitRange using GradedUnitRanges: tensor_product -using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion +using TensorAlgebra: TensorAlgebra, BlockedTuple, BlockReshapeFusion, FusionStyle function TensorAlgebra.:⊗(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) return tensor_product(a1, a2) @@ -19,10 +19,8 @@ function TensorAlgebra.fusedims( return blockreshape(a, axes) end -function TensorAlgebra.splitdims( - ::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange... -) - return blockreshape(a, axes) +function TensorAlgebra.splitdims(::BlockReshapeFusion, a::AbstractArray, bt::BlockedTuple) + return blockreshape(a, Tuple(bt)...) end using BlockArrays: @@ -58,7 +56,6 @@ using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims # TODO: Make a `ReduceWhile` library. -include("reducewhile.jl") TensorAlgebra.FusionStyle(::AbstractGradedUnitRange) = SectorFusion() @@ -85,15 +82,10 @@ function TensorAlgebra.fusedims( return block_mergesort(a_reshaped) end -function TensorAlgebra.splitdims( - ::SectorFusion, a::AbstractArray, split_axes::AbstractUnitRange... -) +function TensorAlgebra.splitdims(::SectorFusion, a::AbstractArray, split_axes::BlockedTuple) # First, fuse axes to get `blockmergesortperm`. # Then unpermute the blocks. - axes_prod = - groupreducewhile(tensor_product, split_axes, ndims(a); init=OneToOne()) do i, axis - return length(axis) ≤ length(axes(a, i)) - end + axes_prod = map(t -> tensor_product(t...), blocks(split_axes)) blockperms = blocksortperm.(axes_prod) sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms) @@ -103,7 +95,7 @@ function TensorAlgebra.splitdims( # for this combination of slicing. a_unblocked = a[sorted_axes...] a_blockpermed = a_unblocked[invblockperm.(blockperms)...] - return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...) + return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes) end # This is a temporary fix for `eachindex` being broken for BlockSparseArrays