Skip to content

Commit 15a87a5

Browse files
committed
Improvements
1 parent 813b436 commit 15a87a5

File tree

4 files changed

+19
-13
lines changed

4 files changed

+19
-13
lines changed

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ function update(
341341
bpc,
342342
alg.kwargs.edge_sequence;
343343
(update_diff!)=diff,
344-
message_update_alg=set_kwargs(message_update_alg),
344+
message_update_alg=set_default_kwargs(message_update_alg),
345345
kwargs...,
346346
)
347347
if compute_error && (diff.x / length(edges)) <= alg.kwargs.tol
@@ -355,7 +355,7 @@ function update(
355355
end
356356

357357
function update(bpc::AbstractBeliefPropagationCache; alg=default_update_alg(bpc), kwargs...)
358-
return update(set_kwargs(alg, bpc), bpc; kwargs...)
358+
return update(set_default_kwargs(alg, bpc), bpc; kwargs...)
359359
end
360360

361361
function rescale_messages(

src/caches/beliefpropagationcache.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using NamedGraphs.PartitionedGraphs:
1414
unpartitioned_graph,
1515
which_partition
1616
using SimpleTraits: SimpleTraits, Not, @traitfn
17-
using NDTensors: NDTensors
17+
using NDTensors: NDTensors, Algorithm
1818

1919
function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork)
2020
return (; partitioned_vertices=default_partitioned_vertices(ψ))
@@ -80,19 +80,25 @@ end
8080
function default_message_update_alg(bp_cache::BeliefPropagationCache)
8181
return Algorithm("contract"; normalize=true, sequence_alg="optimal")
8282
end
83-
function set_kwargs(alg::Algorithm"contract")
84-
normalize = get(alg.kwargs, :normalize, true)
85-
sequence_alg = get(alg.kwargs, :sequence_alg, "optimal")
83+
default_normalize(::Algorithm"contract") = true
84+
default_sequence_alg(::Algorithm"contract") = "optimal"
85+
function set_default_kwargs(alg::Algorithm"contract")
86+
normalize = get(alg.kwargs, :normalize, default_normalize(alg))
87+
sequence_alg = get(alg.kwargs, :sequence_alg, default_sequence_alg(alg))
8688
return Algorithm("contract"; normalize, sequence_alg)
8789
end
88-
function set_kwargs(alg::Algorithm"adapt_update")
89-
return Algorithm("adapt_update"; adapt=alg.kwargs.adapt, alg=set_kwargs(alg.kwargs.alg))
90+
function set_default_kwargs(alg::Algorithm"adapt_update")
91+
return Algorithm(
92+
"adapt_update"; adapt=alg.kwargs.adapt, alg=set_default_kwargs(alg.kwargs.alg)
93+
)
9094
end
91-
function set_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)
92-
verbose = get(alg.kwargs, :verbose, false)
95+
default_verbose(::Algorithm"bp") = false
96+
default_tol(::Algorithm"bp") = nothing
97+
function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)
98+
verbose = get(alg.kwargs, :verbose, default_verbose(alg))
9399
maxiter = get(alg.kwargs, :maxiter, default_bp_maxiter(bp_cache))
94100
edge_sequence = get(alg.kwargs, :edge_sequence, default_bp_edge_sequence(bp_cache))
95-
tol = get(alg.kwargs, :tol, nothing)
101+
tol = get(alg.kwargs, :tol, default_tol(alg))
96102
return Algorithm("bp"; verbose, maxiter, edge_sequence, tol)
97103
end
98104

test/test_apply.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using ITensorNetworks:
1111
random_tensornetwork,
1212
siteinds,
1313
update
14-
using ITensors: ITensors, ITensor, inner, op
14+
using ITensors: ITensors, ITensor, Algorithm, inner, op
1515
using NamedGraphs.NamedGraphGenerators: named_grid
1616
using NamedGraphs.PartitionedGraphs: PartitionVertex
1717
using SplitApplyCombine: group

test/test_gauging.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using ITensorNetworks:
1010
siteinds,
1111
update
1212
using ITensors: diag_itensor, inds, inner
13-
using ITensors.NDTensors: vector
13+
using ITensors.NDTensors: Algorithm, vector
1414
using LinearAlgebra: diag
1515
using NamedGraphs.NamedGraphGenerators: named_grid
1616
using StableRNGs: StableRNG

0 commit comments

Comments
 (0)