Skip to content

Commit 70ca4a7

Browse files
authored
Tweak matricization overloads (#192)
1 parent d76d6ae commit 70ca4a7

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.10.14"
4+
version = "0.10.15"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -42,7 +42,7 @@ MapBroadcast = "0.1.5"
4242
MatrixAlgebraKit = "0.6"
4343
SparseArraysBase = "0.7.1"
4444
SplitApplyCombine = "1.2.3"
45-
TensorAlgebra = "0.6"
45+
TensorAlgebra = "0.6.2"
4646
Test = "1.10"
4747
TypeParameterAccessors = "0.4.1"
4848
julia = "1.10"

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,27 @@ using TensorAlgebra: TensorAlgebra, BlockReshapeFusion, BlockedTuple, matricize,
88
matricize_axes, tensor_product_axis, unmatricize
99

1010
function TensorAlgebra.tensor_product_axis(
11-
::BlockReshapeFusion, r1::BlockUnitRange, r2::BlockUnitRange
11+
style::BlockReshapeFusion, side::Val{:codomain}, r1::BlockUnitRange, r2::BlockUnitRange
12+
)
13+
return tensor_product_blockrange(style, side, r1, r2)
14+
end
15+
function TensorAlgebra.tensor_product_axis(
16+
style::BlockReshapeFusion, side::Val{:domain}, r1::BlockUnitRange, r2::BlockUnitRange
17+
)
18+
return tensor_product_blockrange(style, side, r1, r2)
19+
end
20+
function tensor_product_blockrange(
21+
::BlockReshapeFusion, side::Val, r1::BlockUnitRange, r2::BlockUnitRange
1222
)
1323
(isone(first(r1)) && isone(first(r2))) ||
1424
throw(ArgumentError("Only one-based axes are supported"))
1525
blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2))
16-
blockaxs = vec(splat(tensor_product_axis).(blockaxpairs))
17-
return blockrange(blockaxs)
26+
blockaxs = map(blockaxpairs) do (b1, b2)
27+
# TODO: Store a FusionStyle for the blocks in `BlockReshapeFusion`
28+
# and use that here.
29+
return tensor_product_axis(side, b1, b2)
30+
end
31+
return blockrange(vec(blockaxs))
1832
end
1933

2034
function TensorAlgebra.matricize(

0 commit comments

Comments
 (0)