Skip to content

Commit 2ada7cd

Browse files
committed
Fix tests
1 parent 5b074f1 commit 2ada7cd

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,25 @@ module ITensorNetworksNextTensorOperationsExt
22

33
using BackendSelection: @Algorithm_str, Algorithm
44
using ITensorNetworksNext: ITensorNetworksNext, contraction_order
5+
using ITensorNetworksNext.LazyNamedDimsArrays: symnameddims, substitute
56
using NamedDimsArrays: inds
67
using TensorOperations: TensorOperations, optimaltree
78

8-
function contraction_order_to_expr(seq)
9-
return seq isa AbstractVector ? prod(contraction_order_to_expr, seq) : symnameddims(seq)
9+
function contraction_order_to_expr(ord)
10+
return ord isa AbstractVector ? prod(contraction_order_to_expr, ord) : symnameddims(ord)
1011
end
1112

1213
function ITensorNetworksNext.contraction_order(alg::Algorithm"optimal", tn)
13-
ts = [tn[i] for i in eachindex(tn)]
14+
ts = [tn[i] for i in keys(tn)]
1415
network = collect.(inds.(ts))
1516
# Converting dims to Float64 to minimize overflow issues
1617
inds_to_dims = Dict(i => Float64(length(i)) for i in unique(reduce(vcat, network)))
17-
seq, _ = optimaltree(network, inds_to_dims)
18+
order, _ = optimaltree(network, inds_to_dims)
1819
# TODO: Map the integer indices back to the original tensor network vertices.
19-
expr = contraction_order_to_expr(seq)
20-
subs = Dict(symnameddims(i) => symnameddims(eachindex(tn)[i]) for i in eachindex(ts))
20+
expr = contraction_order_to_expr(order)
21+
verts = collect(keys(tn))
22+
sym(i) = symnameddims(verts[i], Tuple(inds(tn[verts[i]])))
23+
subs = Dict(symnameddims(i) => sym(i) for i in eachindex(verts))
2124
return substitute(expr, subs)
2225
end
2326

test/test_contract_network.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ using Test: @test, @testset
1414
C = ITensor([5.0, 1.0], j)
1515
D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k)
1616

17-
ABCD_1 = contract_network([A, B, C, D]; alg = "exact", order_alg = "leftassociative")
18-
ABCD_2 = contract_network([A, B, C, D]; alg = "exact", order_alg = "optimal")
17+
ABCD_1 = contract_network([A, B, C, D]; order_alg = "left_associative")
18+
ABCD_2 = contract_network([A, B, C, D]; order_alg = "eager")
19+
ABCD_3 = contract_network([A, B, C, D]; order_alg = "optimal")
1920

20-
@test ABCD_1 == ABCD_2
21+
@test ABCD_1 == ABCD_2 == ABCD_3
2122
end
2223

2324
@testset "Contract One Dimensional Network" begin
@@ -30,9 +31,11 @@ using Test: @test, @testset
3031
return randn(Tuple(is))
3132
end
3233

33-
z1 = contract_network(tn; alg = "exact", sequence_alg = "optimal")[]
34-
z2 = contract_network(tn; alg = "exact", sequence_alg = "leftassociative")[]
34+
z1 = contract_network(tn; order_alg = "left_associative")[]
35+
z2 = contract_network(tn; order_alg = "eager")[]
36+
z3 = contract_network(tn; order_alg = "optimal")[]
3537

3638
@test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64)
39+
@test abs(z1 - z3) / abs(z1) <= 1.0e3 * eps(Float64)
3740
end
3841
end

0 commit comments

Comments
 (0)