Skip to content

Commit fc0668a

Browse files
authored
fix tensor_product for zero and one blocklength axes (#13)
1 parent aa2e5f9 commit fc0668a

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorProducts"
22
uuid = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.4"
4+
version = "0.1.5"
55

66
[weakdeps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"

ext/TensorProductsBlockArraysExt/TensorProductsBlockArraysExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ using TensorProducts: OneToOne, TensorProducts
1414
function TensorProducts.tensor_product(
1515
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange
1616
)
17-
new_blocklengths = mapreduce(vcat, Iterators.product(blocks(a1), blocks(a2))) do (x, y)
18-
return length(x) * length(y)
19-
end
17+
new_blocklengths = vec(
18+
map(splat(*), Iterators.product(blocklengths(a1), blocklengths(a2)))
19+
)
2020
return blockedrange(new_blocklengths)
2121
end
2222

test/test_tensor_product.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ using TensorProducts: ⊗, OneToOne, tensor_product
55
using BlockArrays: blockedrange, blockisequal
66

77
r0 = OneToOne()
8-
b1 = blockedrange([1, 2])
8+
b0 = blockedrange(Int[])
9+
b1 = blockedrange([1])
10+
b2 = blockedrange([1, 2])
911

1012
@testset "tensor_product" begin
1113
@test tensor_product() isa OneToOne
@@ -21,8 +23,13 @@ b1 = blockedrange([1, 2])
2123
@test blockisequal(tensor_product(b1, r0), b1)
2224
@test blockisequal(tensor_product(r0, b1), b1)
2325

24-
@test blockisequal(tensor_product(b1, b1), blockedrange([1, 2, 2, 4]))
25-
@test blockisequal(tensor_product(b1, b1), blockedrange([1, 2, 2, 4]))
26+
@test blockisequal(tensor_product(b0, b0), b0)
27+
@test blockisequal(tensor_product(b0, b1), b0)
28+
@test blockisequal(tensor_product(b1, b0), b0)
29+
@test blockisequal(tensor_product(b1, b1), b1)
30+
@test blockisequal(tensor_product(b1, b2), b2)
31+
@test blockisequal(tensor_product(b2, b1), b2)
32+
@test blockisequal(tensor_product(b2, b2), blockedrange([1, 2, 2, 4]))
2633

2734
@test (r0, r0) isa OneToOne
2835
end

0 commit comments

Comments
 (0)