diff --git a/Project.toml b/Project.toml index 60b8c3c..321cdfb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.10.14" +version = "0.10.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -42,7 +42,7 @@ MapBroadcast = "0.1.5" MatrixAlgebraKit = "0.6" SparseArraysBase = "0.7.1" SplitApplyCombine = "1.2.3" -TensorAlgebra = "0.6" +TensorAlgebra = "0.6.2" Test = "1.10" TypeParameterAccessors = "0.4.1" julia = "1.10" diff --git a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl index 4874dcd..3a71e90 100644 --- a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl +++ b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl @@ -8,13 +8,27 @@ using TensorAlgebra: TensorAlgebra, BlockReshapeFusion, BlockedTuple, matricize, matricize_axes, tensor_product_axis, unmatricize function TensorAlgebra.tensor_product_axis( - ::BlockReshapeFusion, r1::BlockUnitRange, r2::BlockUnitRange + style::BlockReshapeFusion, side::Val{:codomain}, r1::BlockUnitRange, r2::BlockUnitRange + ) + return tensor_product_blockrange(style, side, r1, r2) +end +function TensorAlgebra.tensor_product_axis( + style::BlockReshapeFusion, side::Val{:domain}, r1::BlockUnitRange, r2::BlockUnitRange + ) + return tensor_product_blockrange(style, side, r1, r2) +end +function tensor_product_blockrange( + ::BlockReshapeFusion, side::Val, r1::BlockUnitRange, r2::BlockUnitRange ) (isone(first(r1)) && isone(first(r2))) || throw(ArgumentError("Only one-based axes are supported")) blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2)) - blockaxs = vec(splat(tensor_product_axis).(blockaxpairs)) - return blockrange(blockaxs) + blockaxs = map(blockaxpairs) do (b1, b2) + # TODO: Store a FusionStyle for the blocks in `BlockReshapeFusion` + # and use that here. + return tensor_product_axis(side, b1, b2) + end + return blockrange(vec(blockaxs)) end function TensorAlgebra.matricize(