Skip to content

Commit 9f218d7

Browse files
committed
Force specification of contract alg
1 parent 87647df commit 9f218d7

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

src/contractnetwork.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
using BackendSelection: @Algorithm_str, Algorithm
22

3-
default_sequence_alg = "leftassociative"
3+
default_contract_alg = nothing
4+
5+
#Algorithmic defaults
6+
default_sequence(::Algorithm"exact") = "leftassociative"
7+
function set_default_kwargs(alg::Algorithm"exact")
8+
sequence = get(alg, :sequence, default_sequence(alg))
9+
return Algorithm("exact"; sequence)
10+
end
411

512
function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray})
613
return Any[i for i in 1:length(tn)]
@@ -20,11 +27,16 @@ end
2027
rearrange(tn::Vector{<:AbstractArray}, i::Integer) = tn[i]
2128
rearrange(tn::Vector{<:AbstractArray}, v::AbstractVector) = [rearrange(tn, s) for s in v]
2229

23-
function contractnetwork(tn::Vector{<:AbstractArray}; sequence=default_sequence_alg)
24-
contract_sequence = isa(sequence, String) ? contraction_sequence(tn; alg=sequence) : sequence
30+
function contractnetwork(alg::Algorithm"exact", tn::Vector{<:AbstractArray})
31+
contract_sequence = isa(alg.sequence, String) ? contraction_sequence(tn; alg=alg.sequence) : sequence
2532
return recursive_contractnetwork(rearrange(tn, contract_sequence))
2633
end
2734

28-
function contractnetwork(tn::AbstractTensorNetwork; sequence=default_sequence_alg)
29-
return contractnetwork([tn[v] for v in vertices(tn)]; sequence)
35+
function contractnetwork(alg::Algorithm"exact", tn::AbstractTensorNetwork)
36+
return contractnetwork(alg, [tn[v] for v in vertices(tn)])
3037
end
38+
39+
function contractnetwork(tn::Union{AbstractTensorNetwork, Vector{<:AbstractArray}}; alg = default_contract_alg, kwargs...)
40+
alg == nothing && error("Must specify an algorithm to contract the network with")
41+
return contractnetwork(set_default_kwargs(Algorithm(alg; kwargs...)), tn)
42+
end

test/test_contractnetwork.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ using Test: @test, @testset
1515
C = ITensor([5.0, 1.0], j)
1616
D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k)
1717

18-
ABCD_1 = contractnetwork([A, B, C, D]; sequence="leftassociative")
19-
ABCD_2 = contractnetwork([A, B, C, D]; sequence="optimal")
18+
ABCD_1 = contractnetwork([A, B, C, D]; alg = "exact", sequence="leftassociative")
19+
ABCD_2 = contractnetwork([A, B, C, D]; alg = "exact", sequence="optimal")
2020

2121
@test ABCD_1 == ABCD_2
2222
end
@@ -31,8 +31,8 @@ using Test: @test, @testset
3131
return randn(Tuple(is))
3232
end
3333

34-
z1 = contractnetwork(tn; sequence="optimal")[]
35-
z2 = contractnetwork(tn; sequence="leftassociative")[]
34+
z1 = contractnetwork(tn; alg = "exact", sequence="optimal")[]
35+
z2 = contractnetwork(tn; alg = "exact", sequence="leftassociative")[]
3636

3737
@test abs(z1 - z2) / abs(z1) <= 1e3*eps(Float64)
3838
end

0 commit comments

Comments
 (0)