@@ -25,7 +25,7 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
2525end
2626data_graph (bpc:: AbstractBeliefPropagationCache ) = data_graph (tensornetwork (bpc))
2727
28- function message_update (alg:: Algorithm"contract" , contract_list:: Vector{ITensor} ; )
28+ function message_update (alg:: Algorithm"contract" , contract_list:: Vector{ITensor} )
2929 sequence = contraction_sequence (contract_list; alg= alg. kwargs. sequence_alg)
3030 updated_messages = contract (contract_list; sequence)
3131 message_norm = norm (updated_messages)
@@ -35,10 +35,10 @@ function message_update(alg::Algorithm"contract", contract_list::Vector{ITensor}
3535 return ITensor[updated_messages]
3636end
3737
38- function message_update (alg:: Algorithm"adapt_update" , contract_list:: Vector{ITensor} ; )
38+ function message_update (alg:: Algorithm"adapt_update" , contract_list:: Vector{ITensor} )
3939 adapted_contract_list = alg. kwargs. adapt .(contract_list)
4040 updated_messages = message_update (alg. kwargs. alg, adapted_contract_list)
41- dtype = datatype ( first ( contract_list) )
41+ dtype = mapreduce (datatype, promote_type, contract_list)
4242 return map (adapt (dtype), updated_messages)
4343end
4444
@@ -327,21 +327,25 @@ More generic interface for update, with default params
327327function update (
328328 alg:: Algorithm ,
329329 bpc:: AbstractBeliefPropagationCache ;
330- edges= alg. kwargs. edge_sequence,
331- tol= alg. kwargs. tol,
332- maxiter= alg. kwargs. maxiter,
333- verbose= alg. kwargs. verbose,
330+ message_update_alg= default_message_update_alg (bpc),
334331 kwargs... ,
335332)
336- compute_error = ! isnothing (tol)
337- if isnothing (maxiter)
333+ compute_error = ! isnothing (alg . kwargs . tol)
334+ if isnothing (alg . kwargs . maxiter)
338335 error (" You need to specify a number of iterations for BP!" )
339336 end
340- for i in 1 : maxiter
337+ for i in 1 : alg . kwargs . maxiter
341338 diff = compute_error ? Ref (0.0 ) : nothing
342- bpc = update (alg, bpc, edges; (update_diff!)= diff, kwargs... )
343- if compute_error && (diff. x / length (edges)) <= tol
344- if verbose
339+ bpc = update (
340+ alg,
341+ bpc,
342+ alg. kwargs. edge_sequence;
343+ (update_diff!)= diff,
344+ message_update_alg= set_kwargs (message_update_alg),
345+ kwargs... ,
346+ )
347+ if compute_error && (diff. x / length (edges)) <= alg. kwargs. tol
348+ if alg. kwargs. verbose
345349 println (" BP converged to desired precision after $i iterations." )
346350 end
347351 break
@@ -351,7 +355,7 @@ function update(
351355end
352356
353357function update (bpc:: AbstractBeliefPropagationCache ; alg= default_update_alg (bpc), kwargs... )
354- return update (Algorithm (alg), bpc; kwargs... )
358+ return update (set_kwargs (alg, bpc ), bpc; kwargs... )
355359end
356360
357361function rescale_messages (
0 commit comments