Skip to content

Commit 389f148

Browse files
committed
Simplify interface, fix tests
1 parent 15a87a5 commit 389f148

File tree

7 files changed

+77
-83
lines changed

7 files changed

+77
-83
lines changed

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 49 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,6 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
2525
end
2626
data_graph(bpc::AbstractBeliefPropagationCache) = data_graph(tensornetwork(bpc))
2727

28-
function message_update(alg::Algorithm"contract", contract_list::Vector{ITensor})
29-
sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg)
30-
updated_messages = contract(contract_list; sequence)
31-
message_norm = norm(updated_messages)
32-
if alg.kwargs.normalize && !iszero(message_norm)
33-
updated_messages /= message_norm
34-
end
35-
return ITensor[updated_messages]
36-
end
37-
38-
function message_update(alg::Algorithm"adapt_update", contract_list::Vector{ITensor})
39-
adapted_contract_list = alg.kwargs.adapt.(contract_list)
40-
updated_messages = message_update(alg.kwargs.alg, adapted_contract_list)
41-
dtype = mapreduce(datatype, promote_type, contract_list)
42-
return map(adapt(dtype), updated_messages)
43-
end
44-
4528
#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages
4629
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
4730
lhs, rhs = contract(message_a), contract(message_b)
@@ -144,22 +127,28 @@ function incoming_messages(
144127
end
145128

146129
#Adapt interface for changing device
147-
function map_messages(f, bpc::AbstractBeliefPropagationCache)
130+
function map_messages(
131+
f, bpc::AbstractBeliefPropagationCache, pes=collect(keys(messages(bpc)))
132+
)
148133
bpc = copy(bpc)
149-
for pe in keys(messages(bpc))
134+
for pe in pes
150135
set_message!(bpc, pe, f.(message(bpc, pe)))
151136
end
152137
return bpc
153138
end
154-
function map_factors(f, bpc::AbstractBeliefPropagationCache)
139+
function map_factors(f, bpc::AbstractBeliefPropagationCache, vs=vertices(bpc))
155140
bpc = copy(bpc)
156-
for v in vertices(bpc)
141+
for v in vs
157142
@preserve_graph bpc[v] = f(bpc[v])
158143
end
159144
return bpc
160145
end
161-
adapt_messages(to, bpc::AbstractBeliefPropagationCache) = map_messages(adapt(to), bpc)
162-
adapt_factors(to, bpc::AbstractBeliefPropagationCache) = map_factors(adapt(to), bpc)
146+
function adapt_messages(to, bpc::AbstractBeliefPropagationCache, args...)
147+
map_messages(adapt(to), bpc, args...)
148+
end
149+
function adapt_factors(to, bpc::AbstractBeliefPropagationCache, args...)
150+
map_factors(adapt(to), bpc, args...)
151+
end
163152

164153
function Adapt.adapt_structure(to, bpc::AbstractBeliefPropagationCache)
165154
bpc = adapt_messages(to, bpc)
@@ -257,33 +246,49 @@ function delete_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
257246
return delete_messages(bpc, [pe])
258247
end
259248

260-
"""
261-
Compute message tensor as product of incoming mts and local state
262-
"""
263249
function updated_message(
264-
bpc::AbstractBeliefPropagationCache,
265-
edge::PartitionEdge;
266-
message_update_alg=default_message_update_alg(bpc),
267-
kwargs...,
250+
alg::Algorithm"contract", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
268251
)
269252
vertex = src(edge)
270253
incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
271254
state = factors(bpc, vertex)
255+
contract_list = ITensor[incoming_ms; state]
256+
sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg)
257+
updated_messages = contract(contract_list; sequence)
258+
message_norm = norm(updated_messages)
259+
if alg.kwargs.normalize && !iszero(message_norm)
260+
updated_messages /= message_norm
261+
end
262+
return ITensor[updated_messages]
263+
end
272264

273-
return message_update(message_update_alg, ITensor[incoming_ms; state]; kwargs...)
265+
function updated_message(
266+
alg::Algorithm"adapt_update", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
267+
)
268+
incoming_pes = setdiff(
269+
boundary_partitionedges(bpc, [src(edge)]; dir=:in), [reverse(edge)]
270+
)
271+
adapted_bpc = adapt_messages(alg.kwargs.adapt, bpc, incoming_pes)
272+
adapted_bpc = adapt_factors(alg.kwargs.adapt, bpc, vertices(bpc, src(edge)))
273+
updated_messages = updated_message(alg.kwargs.alg, adapted_bpc, edge)
274+
dtype = mapreduce(datatype, promote_type, message(bpc, edge))
275+
return map(adapt(dtype), updated_messages)
274276
end
275277

276-
function update(
277-
alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...
278+
function update_message(
279+
message_update_alg::Algorithm,
280+
bpc::AbstractBeliefPropagationCache,
281+
edge::PartitionEdge;
282+
kwargs...,
278283
)
279-
return set_message(bpc, edge, updated_message(bpc, edge; kwargs...))
284+
return set_message(bpc, edge, updated_message(message_update_alg, bpc, edge; kwargs...))
280285
end
281286

282287
"""
283288
Do a sequential update of the message tensors on `edges`
284289
"""
285-
function update(
286-
alg::Algorithm,
290+
function update_one_iteration(
291+
alg::Algorithm"bp",
287292
bpc::AbstractBeliefPropagationCache,
288293
edges::Vector;
289294
(update_diff!)=nothing,
@@ -292,7 +297,7 @@ function update(
292297
bpc = copy(bpc)
293298
for e in edges
294299
prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing
295-
bpc = update(alg, bpc, e; kwargs...)
300+
bpc = update_message(alg.kwargs.message_update_alg, bpc, e; kwargs...)
296301
if !isnothing(update_diff!)
297302
update_diff![] += message_diff(message(bpc, e), prev_message)
298303
end
@@ -305,15 +310,15 @@ Do parallel updates between groups of edges of all message tensors
305310
Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
306311
mts relevant to that group.
307312
"""
308-
function update(
313+
function update_one_iteration(
309314
alg::Algorithm,
310315
bpc::AbstractBeliefPropagationCache,
311316
edge_groups::Vector{<:Vector{<:PartitionEdge}};
312317
kwargs...,
313318
)
314319
new_mts = empty(messages(bpc))
315320
for edges in edge_groups
316-
bpc_t = update(alg, bpc, edges; kwargs...)
321+
bpc_t = update_one_iteration(alg.kwargs.message_update_alg, bpc, edges; kwargs...)
317322
for e in edges
318323
set!(new_mts, e, message(bpc_t, e))
319324
end
@@ -324,27 +329,17 @@ end
324329
"""
325330
More generic interface for update, with default params
326331
"""
327-
function update(
328-
alg::Algorithm,
329-
bpc::AbstractBeliefPropagationCache;
330-
message_update_alg=default_message_update_alg(bpc),
331-
kwargs...,
332-
)
332+
function update(alg::Algorithm, bpc::AbstractBeliefPropagationCache; kwargs...)
333333
compute_error = !isnothing(alg.kwargs.tol)
334334
if isnothing(alg.kwargs.maxiter)
335335
error("You need to specify a number of iterations for BP!")
336336
end
337337
for i in 1:alg.kwargs.maxiter
338338
diff = compute_error ? Ref(0.0) : nothing
339-
bpc = update(
340-
alg,
341-
bpc,
342-
alg.kwargs.edge_sequence;
343-
(update_diff!)=diff,
344-
message_update_alg=set_default_kwargs(message_update_alg),
345-
kwargs...,
339+
bpc = update_one_iteration(
340+
alg, bpc, alg.kwargs.edge_sequence; (update_diff!)=diff, kwargs...
346341
)
347-
if compute_error && (diff.x / length(edges)) <= alg.kwargs.tol
342+
if compute_error && (diff.x / length(alg.kwargs.edge_sequence)) <= alg.kwargs.tol
348343
if alg.kwargs.verbose
349344
println("BP converged to desired precision after $i iterations.")
350345
end
@@ -355,7 +350,7 @@ function update(
355350
end
356351

357352
function update(bpc::AbstractBeliefPropagationCache; alg=default_update_alg(bpc), kwargs...)
358-
return update(set_default_kwargs(alg, bpc), bpc; kwargs...)
353+
return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc)
359354
end
360355

361356
function rescale_messages(

src/caches/beliefpropagationcache.jl

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,8 @@ function Base.copy(bp_cache::BeliefPropagationCache)
6868
)
6969
end
7070

71-
function default_update_alg(bp_cache::BeliefPropagationCache)
72-
Algorithm(
73-
"bp";
74-
verbose=false,
75-
maxiter=default_bp_maxiter(bp_cache),
76-
edge_sequence=default_bp_edge_sequence(bp_cache),
77-
tol=nothing,
78-
)
79-
end
80-
function default_message_update_alg(bp_cache::BeliefPropagationCache)
81-
return Algorithm("contract"; normalize=true, sequence_alg="optimal")
82-
end
71+
default_update_alg(bp_cache::BeliefPropagationCache) = "bp"
72+
default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract"
8373
default_normalize(::Algorithm"contract") = true
8474
default_sequence_alg(::Algorithm"contract") = "optimal"
8575
function set_default_kwargs(alg::Algorithm"contract")
@@ -88,9 +78,8 @@ function set_default_kwargs(alg::Algorithm"contract")
8878
return Algorithm("contract"; normalize, sequence_alg)
8979
end
9080
function set_default_kwargs(alg::Algorithm"adapt_update")
91-
return Algorithm(
92-
"adapt_update"; adapt=alg.kwargs.adapt, alg=set_default_kwargs(alg.kwargs.alg)
93-
)
81+
_alg = set_default_kwargs(get(alg.kwargs, :alg, Algorithm("contract")))
82+
return Algorithm("adapt_update"; adapt=alg.kwargs.adapt, alg=_alg)
9483
end
9584
default_verbose(::Algorithm"bp") = false
9685
default_tol(::Algorithm"bp") = nothing
@@ -99,7 +88,10 @@ function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache
9988
maxiter = get(alg.kwargs, :maxiter, default_bp_maxiter(bp_cache))
10089
edge_sequence = get(alg.kwargs, :edge_sequence, default_bp_edge_sequence(bp_cache))
10190
tol = get(alg.kwargs, :tol, default_tol(alg))
102-
return Algorithm("bp"; verbose, maxiter, edge_sequence, tol)
91+
message_update_alg = set_default_kwargs(
92+
get(alg.kwargs, :message_update_alg, Algorithm(default_message_update_alg(bp_cache)))
93+
)
94+
return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg)
10395
end
10496

10597
function default_bp_maxiter(bp_cache::BeliefPropagationCache)

test/test_apply.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ using Test: @test, @testset
3030
ψψ = norm_sqr_network(ψ)
3131
#Simple Belief Propagation Grouping
3232
bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
33-
bp_cache = update(bp_cache; alg=Algorithm("bp"; maxiter=20))
33+
bp_cache = update(bp_cache; maxiter=20)
3434
envsSBP = environment(bp_cache, [(v1, "bra"), (v1, "ket"), (v2, "bra"), (v2, "ket")])
35-
ψv = VidalITensorNetwork(ψ)
35+
ψv = VidalITensorNetwork; cache_update_kwargs=(; maxiter=20))
3636
#This grouping will correspond to calculating the environments exactly (each column of the grid is a partition)
3737
bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1][1], vertices(ψψ)))
38-
bp_cache = update(bp_cache; alg=Algorithm("bp"; maxiter=20))
38+
bp_cache = update(bp_cache; maxiter=20)
3939
envsGBP = environment(bp_cache, [(v1, "bra"), (v1, "ket"), (v2, "bra"), (v2, "ket")])
4040
inner_alg = "exact"
4141
ngates = 5

test/test_belief_propagation.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ using ITensorNetworks:
2525
update_factor,
2626
updated_message,
2727
message_diff
28-
using ITensors: ITensors, ITensor, combiner, dag, inds, inner, op, prime, random_itensor
28+
using ITensors:
29+
ITensors, ITensor, Algorithm, combiner, dag, inds, inner, op, prime, random_itensor
2930
using ITensorNetworks.ModelNetworks: ModelNetworks
3031
using ITensors.NDTensors: array
3132
using LinearAlgebra: eigvals, tr
@@ -62,7 +63,7 @@ using Test: @test, @testset
6263
@test bpc[vket] == new_A
6364
@test bpc[vbra] == new_A_dag
6465

65-
bpc = update(bpc; alg=Algorithm("bp"; maxiter=25, tol=eps(real(elt))))
66+
bpc = update(bpc; alg="bp", maxiter=25, tol=eps(real(elt)))
6667
#Test messages are converged
6768
for pe in partitionedges(bpc)
6869
@test message_diff(updated_message(bpc, pe), message(bpc, pe)) < 10 * eps(real(elt))

test/test_expect.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ using Test: @test, @testset
3636
cache_construction_kwargs = (;
3737
partitioned_vertices=group(v -> first(first(v)), quadratic_form_vertices)
3838
)
39-
sz_bp = expect(ψ, "Sz"; alg="bp", cache_construction_kwargs)
39+
sz_bp = expect(
40+
ψ, "Sz"; alg="bp", cache_construction_kwargs, cache_update_kwargs=(; maxiter=20)
41+
)
4042
sz_exact = expect(ψ, "Sz"; alg="exact")
4143
@test sz_bp sz_exact
4244

@@ -47,7 +49,7 @@ using Test: @test, @testset
4749

4850
ψ = ITensorNetwork(v -> isodd(sum(v)) ? "" : "", s)
4951

50-
sz_bp = expect(ψ, "Sz"; alg="bp")
52+
sz_bp = expect(ψ, "Sz"; alg="bp", cache_update_kwargs=(; maxiter=20))
5153
sz_exact = expect(ψ, "Sz"; alg="exact")
5254
@test sz_bp sz_exact
5355
end

test/test_gauging.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ using Test: @test, @testset
2929

3030
# Move directly to vidal gauge
3131
ψ_vidal = VidalITensorNetwork(
32-
ψ; cache_update_kwargs=(; alg=Algorithm("bp"; maxiter=30, verbose=true))
32+
ψ; cache_update_kwargs=(; alg="bp", maxiter=30, verbose=true)
3333
)
3434
@test gauge_error(ψ_vidal) < 1e-8
3535

@@ -39,10 +39,12 @@ using Test: @test, @testset
3939
bp_cache = cache_ref[]
4040

4141
# Test we just did a gauge transform and didn't change the overall network
42-
@test inner(ψ_symm, ψ) / sqrt(inner(ψ_symm, ψ_symm) * inner(ψ, ψ)) 1.0 atol = 1e-8
42+
@test inner(ψ_symm, ψ; alg="exact") /
43+
sqrt(inner(ψ_symm, ψ_symm; alg="exact") * inner(ψ, ψ; alg="exact")) 1.0 atol =
44+
1e-8
4345

4446
#Test all message tensors are approximately diagonal even when we keep running BP
45-
bp_cache = update(bp_cache; alg=Algorithm("bp"; maxiter=10))
47+
bp_cache = update(bp_cache; maxiter=10)
4648
for m_e in values(messages(bp_cache))
4749
@test diag_itensor(vector(diag(only(m_e))), inds(only(m_e))) only(m_e) atol = 1e-8
4850
end

test/test_normalize.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ using Test: @test, @testset
2929
tn_r = rescale(tn; alg="exact")
3030
@test scalar(tn_r; alg="exact") 1.0
3131

32-
tn_r = rescale(tn; alg="bp")
32+
tn_r = rescale(tn; alg="bp", cache_update_kwargs=(; maxiter=20))
3333
@test scalar(tn_r; alg="exact") 1.0
3434

3535
#Now a state on a loopy graph
@@ -45,10 +45,12 @@ using Test: @test, @testset
4545
@test scalar(norm_sqr_network(ψ); alg="exact") 1.0
4646

4747
ψIψ_bpc = Ref(BeliefPropagationCache(QuadraticFormNetwork(x)))
48-
ψ = normalize(x; alg="bp", (cache!)=ψIψ_bpc, update_cache=true)
48+
ψ = normalize(
49+
x; alg="bp", (cache!)=ψIψ_bpc, update_cache=true, cache_update_kwargs=(; maxiter=20)
50+
)
4951
ψIψ_bpc = ψIψ_bpc[]
5052
@test all(x -> x 1.0, edge_scalars(ψIψ_bpc))
5153
@test all(x -> x 1.0, vertex_scalars(ψIψ_bpc))
52-
@test scalar(QuadraticFormNetwork(ψ); alg="bp") 1.0
54+
@test scalar(QuadraticFormNetwork(ψ); alg="bp", cache_update_kwargs=(; maxiter=20)) 1.0
5355
end
5456
end

0 commit comments

Comments
 (0)