Skip to content

Commit 7eace5b

Browse files
committed
Use internal ITensorNetworks code for sequence finding
1 parent 85aeb23 commit 7eace5b

File tree

3 files changed

+16
-15
lines changed

3 files changed

+16
-15
lines changed

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ using NDTensors: NDTensors
1717
abstract type AbstractBeliefPropagationCache end
1818

1919
function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
20-
updated_messages = contract(contract_list; sequence="automatic", kwargs...)
20+
sequence = contraction_sequence(contract_list; alg="optimal")
21+
updated_messages = contract(contract_list; sequence, kwargs...)
2122
message_norm = norm(updated_messages)
2223
if normalize && !iszero(message_norm)
2324
updated_messages /= message_norm

src/caches/beliefpropagationcache.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,20 @@ function environment(bpc::BeliefPropagationCache, verts::Vector; kwargs...)
9494
return vcat(messages, central_tensors)
9595
end
9696

97-
function region_scalar(
98-
bp_cache::BeliefPropagationCache,
99-
pv::PartitionVertex;
100-
contract_kwargs=(; sequence="automatic"),
101-
)
97+
function region_scalar(bp_cache::BeliefPropagationCache, pv::PartitionVertex;)
10298
incoming_mts = incoming_messages(bp_cache, [pv])
10399
local_state = factors(bp_cache, pv)
104-
return contract(vcat(incoming_mts, local_state); contract_kwargs...)[]
100+
ts = vcat(incoming_mts, local_state)
101+
sequence = contraction_sequence(ts; alg="optimal")
102+
return contract(ts; sequence)[]
105103
end
106104

107105
function region_scalar(
108106
bp_cache::BeliefPropagationCache,
109107
pe::PartitionEdge;
110108
contract_kwargs=(; sequence="automatic"),
111109
)
112-
return contract(
113-
vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))); contract_kwargs...
114-
)[]
110+
ts = vcat(message(bp_cache, pe), message(bp_cache, reverse(pe)))
111+
sequence = contraction_sequence(ts; alg="optimal")
112+
return contract(ts; sequence)[]
115113
end

src/expect.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@ using ITensorMPS: ITensorMPS, expect
44

55
default_expect_alg() = "bp"
66

7-
function ITensorMPS.expect(
8-
ψIψ::AbstractFormNetwork, op::Op; contract_kwargs=(; sequence="automatic"), kwargs...
9-
)
7+
function ITensorMPS.expect(ψIψ::AbstractFormNetwork, op::Op; kwargs...)
108
v = only(op.sites)
119
ψIψ_v = ψIψ[operator_vertex(ψIψ, v)]
1210
s = commonind(ψIψ[ket_vertex(ψIψ, v)], ψIψ_v)
1311
operator = ITensors.op(op.which_op, s)
1412
∂ψIψ_∂v = environment(ψIψ, operator_vertices(ψIψ, [v]); kwargs...)
15-
numerator = contract(vcat(∂ψIψ_∂v, operator); contract_kwargs...)[]
16-
denominator = contract(vcat(∂ψIψ_∂v, ψIψ_v); contract_kwargs...)[]
13+
numerator_ts = vcat(∂ψIψ_∂v, operator)
14+
denominator_ts = vcat(∂ψIψ_∂v, ψIψ_v)
15+
numerator_seq = contraction_sequence(numerator_ts; alg="optimal")
16+
denominator_seq = contraction_sequence(denominator_ts; alg="optimal")
17+
numerator = contract(numerator_ts; sequence)[]
18+
denominator = contract(denominator_ts; sequence)[]
1719

1820
return numerator / denominator
1921
end

0 commit comments

Comments
 (0)