Skip to content

Commit 6222386

Browse files
committed
set kwargs on algorithms
1 parent 1432a4d commit 6222386

File tree

2 files changed

+34
-15
lines changed

2 files changed

+34
-15
lines changed

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
2525
end
2626
data_graph(bpc::AbstractBeliefPropagationCache) = data_graph(tensornetwork(bpc))
2727

28-
function message_update(alg::Algorithm"contract", contract_list::Vector{ITensor};)
28+
function message_update(alg::Algorithm"contract", contract_list::Vector{ITensor})
2929
sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg)
3030
updated_messages = contract(contract_list; sequence)
3131
message_norm = norm(updated_messages)
@@ -35,10 +35,10 @@ function message_update(alg::Algorithm"contract", contract_list::Vector{ITensor}
3535
return ITensor[updated_messages]
3636
end
3737

38-
function message_update(alg::Algorithm"adapt_update", contract_list::Vector{ITensor};)
38+
function message_update(alg::Algorithm"adapt_update", contract_list::Vector{ITensor})
3939
adapted_contract_list = alg.kwargs.adapt.(contract_list)
4040
updated_messages = message_update(alg.kwargs.alg, adapted_contract_list)
41-
dtype = datatype(first(contract_list))
41+
dtype = mapreduce(datatype, promote_type, contract_list)
4242
return map(adapt(dtype), updated_messages)
4343
end
4444

@@ -327,21 +327,25 @@ More generic interface for update, with default params
327327
function update(
328328
alg::Algorithm,
329329
bpc::AbstractBeliefPropagationCache;
330-
edges=alg.kwargs.edge_sequence,
331-
tol=alg.kwargs.tol,
332-
maxiter=alg.kwargs.maxiter,
333-
verbose=alg.kwargs.verbose,
330+
message_update_alg=default_message_update_alg(bpc),
334331
kwargs...,
335332
)
336-
compute_error = !isnothing(tol)
337-
if isnothing(maxiter)
333+
compute_error = !isnothing(alg.kwargs.tol)
334+
if isnothing(alg.kwargs.maxiter)
338335
error("You need to specify a number of iterations for BP!")
339336
end
340-
for i in 1:maxiter
337+
for i in 1:alg.kwargs.maxiter
341338
diff = compute_error ? Ref(0.0) : nothing
342-
bpc = update(alg, bpc, edges; (update_diff!)=diff, kwargs...)
343-
if compute_error && (diff.x / length(edges)) <= tol
344-
if verbose
339+
bpc = update(
340+
alg,
341+
bpc,
342+
alg.kwargs.edge_sequence;
343+
(update_diff!)=diff,
344+
message_update_alg=set_kwargs(message_update_alg),
345+
kwargs...,
346+
)
347+
if compute_error && (diff.x / length(edges)) <= alg.kwargs.tol
348+
if alg.kwargs.verbose
345349
println("BP converged to desired precision after $i iterations.")
346350
end
347351
break
@@ -351,7 +355,7 @@ function update(
351355
end
352356

353357
function update(bpc::AbstractBeliefPropagationCache; alg=default_update_alg(bpc), kwargs...)
354-
return update(Algorithm(alg), bpc; kwargs...)
358+
return update(set_kwargs(alg, bpc), bpc; kwargs...)
355359
end
356360

357361
function rescale_messages(

src/caches/beliefpropagationcache.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,30 @@ end
7171
function default_update_alg(bp_cache::BeliefPropagationCache)
7272
Algorithm(
7373
"bp";
74-
tol=1e-12,
7574
verbose=false,
7675
maxiter=default_bp_maxiter(bp_cache),
7776
edge_sequence=default_bp_edge_sequence(bp_cache),
77+
tol=nothing,
7878
)
7979
end
8080
function default_message_update_alg(bp_cache::BeliefPropagationCache)
8181
Algorithm("contract"; normalize=true, sequence_alg="optimal")
8282
end
83+
function set_kwargs(alg::Algorithm"contract")
84+
normalize = get(alg.kwargs, :normalize, true)
85+
sequence_alg = get(alg.kwargs, :sequence_alg, "optimal")
86+
return Algorithm("contract"; normalize, sequence_alg)
87+
end
88+
function set_kwargs(alg::Algorithm"adapt_update")
89+
return Algorithm("adapt_update"; adapt=alg.kwargs.adapt, alg=set_kwargs(alg.kwargs.alg))
90+
end
91+
function set_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)
92+
verbose = get(alg.kwargs, :verbose, false)
93+
maxiter = get(alg.kwargs, :maxiter, default_bp_maxiter(bp_cache))
94+
edge_sequence = get(alg.kwargs, :edge_sequence, default_bp_edge_sequence(bp_cache))
95+
tol = get(alg.kwargs, :tol, nothing)
96+
return Algorithm("bp"; verbose, maxiter, edge_sequence, tol)
97+
end
8398

8499
function default_bp_maxiter(bp_cache::BeliefPropagationCache)
85100
return default_bp_maxiter(partitioned_graph(bp_cache))

0 commit comments

Comments
 (0)