diff --git a/Project.toml b/Project.toml index f08f14b6..49f03674 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworks" uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7" authors = ["Matthew Fishman , Joseph Tindall and contributors"] -version = "0.11.26" +version = "0.11.27" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 28da183a..57d13e88 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -28,6 +28,7 @@ include("edge_sequences.jl") include("formnetworks/abstractformnetwork.jl") include("formnetworks/bilinearformnetwork.jl") include("formnetworks/quadraticformnetwork.jl") +include("caches/abstractbeliefpropagationcache.jl") include("caches/beliefpropagationcache.jl") include("contraction_tree_to_graph.jl") include("gauging.jl") diff --git a/src/caches/abstractbeliefpropagationcache.jl b/src/caches/abstractbeliefpropagationcache.jl new file mode 100644 index 00000000..01c90c04 --- /dev/null +++ b/src/caches/abstractbeliefpropagationcache.jl @@ -0,0 +1,289 @@ +using Graphs: IsDirected +using SplitApplyCombine: group +using LinearAlgebra: diag, dot +using ITensors: dir +using ITensorMPS: ITensorMPS +using NamedGraphs.PartitionedGraphs: + PartitionedGraphs, + PartitionedGraph, + PartitionVertex, + boundary_partitionedges, + partitionvertices, + partitionedges, + unpartitioned_graph +using SimpleTraits: SimpleTraits, Not, @traitfn +using NDTensors: NDTensors + +abstract type AbstractBeliefPropagationCache end + +function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...) + sequence = optimal_contraction_sequence(contract_list) + updated_messages = contract(contract_list; sequence, kwargs...) + message_norm = norm(updated_messages) + if normalize && !iszero(message_norm) + updated_messages /= message_norm + end + return ITensor[updated_messages] +end + +#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages +function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) + lhs, rhs = contract(message_a), contract(message_b) + f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs))) + return 1 - f +end + +default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e] +default_messages(ptn::PartitionedGraph) = Dictionary() +@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing +@traitfn function default_bp_maxiter(g::::IsDirected) + return default_bp_maxiter(undirected_graph(underlying_graph(g))) +end +default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) +function default_partitioned_vertices(f::AbstractFormNetwork) + return group(v -> original_state_vertex(f, v), vertices(f)) +end + +partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented() +messages(bpc::AbstractBeliefPropagationCache) = not_implemented() +function default_message( + bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs... +) + return not_implemented() +end +default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented() +Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented() +default_bp_maxiter(alg::Algorithm, bpc::AbstractBeliefPropagationCache) = not_implemented() +function default_edge_sequence(alg::Algorithm, bpc::AbstractBeliefPropagationCache) + return not_implemented() +end +function default_message_update_kwargs(alg::Algorithm, bpc::AbstractBeliefPropagationCache) + return not_implemented() +end +function environment(bpc::AbstractBeliefPropagationCache, verts::Vector; kwargs...) + return not_implemented() +end +function region_scalar(bpc::AbstractBeliefPropagationCache, pv::PartitionVertex; kwargs...) + return not_implemented() +end +function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; kwargs...) + return not_implemented() +end +partitions(bpc::AbstractBeliefPropagationCache) = not_implemented() +partitionpairs(bpc::AbstractBeliefPropagationCache) = not_implemented() + +function default_edge_sequence( + bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc) +) + return default_edge_sequence(Algorithm(alg), bpc) +end +function default_bp_maxiter( + bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc) +) + return default_bp_maxiter(Algorithm(alg), bpc) +end +function default_message_update_kwargs( + bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc) +) + return default_message_update_kwargs(Algorithm(alg), bpc) +end + +function tensornetwork(bpc::AbstractBeliefPropagationCache) + return unpartitioned_graph(partitioned_tensornetwork(bpc)) +end + +function factors(bpc::AbstractBeliefPropagationCache, verts::Vector) + return ITensor[tensornetwork(bpc)[v] for v in verts] +end + +function factors( + bpc::AbstractBeliefPropagationCache, partition_verts::Vector{<:PartitionVertex} +) + return factors(bpc, vertices(bpc, partition_verts)) +end + +function factors(bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex) + return factors(bpc, [partition_vertex]) +end + +function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc); kwargs...) + return map(pv -> region_scalar(bpc, pv; kwargs...), pvs) +end + +function edge_scalars( + bpc::AbstractBeliefPropagationCache, pes=partitionpairs(bpc); kwargs... +) + return map(pe -> region_scalar(bpc, pe; kwargs...), pes) +end + +function scalar_factors_quotient(bpc::AbstractBeliefPropagationCache) + return vertex_scalars(bpc), edge_scalars(bpc) +end + +function incoming_messages( + bpc::AbstractBeliefPropagationCache, + partition_vertices::Vector{<:PartitionVertex}; + ignore_edges=(), +) + bpes = boundary_partitionedges(bpc, partition_vertices; dir=:in) + ms = messages(bpc, setdiff(bpes, ignore_edges)) + return reduce(vcat, ms; init=ITensor[]) +end + +function incoming_messages( + bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex; kwargs... +) + return incoming_messages(bpc, [partition_vertex]; kwargs...) +end + +#Forward from partitioned graph +for f in [ + :(PartitionedGraphs.partitioned_graph), + :(PartitionedGraphs.partitionedge), + :(PartitionedGraphs.partitionvertices), + :(PartitionedGraphs.vertices), + :(PartitionedGraphs.boundary_partitionedges), + :(ITensorMPS.linkinds), +] + @eval begin + function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...) + return $f(partitioned_tensornetwork(bpc), args...; kwargs...) + end + end +end + +NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc)) + +""" +Update the tensornetwork inside the cache +""" +function update_factors(bpc::AbstractBeliefPropagationCache, factors) + bpc = copy(bpc) + tn = tensornetwork(bpc) + for vertex in eachindex(factors) + # TODO: Add a check that this preserves the graph structure. + setindex_preserve_graph!(tn, factors[vertex], vertex) + end + return bpc +end + +function update_factor(bpc, vertex, factor) + return update_factors(bpc, Dictionary([vertex], [factor])) +end + +function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...) + mts = messages(bpc) + return get(() -> default_message(bpc, edge; kwargs...), mts, edge) +end +function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...) + return map(edge -> message(bpc, edge; kwargs...), edges) +end +function set_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message) + bpc = copy(bpc) + ms = messages(bpc) + set!(ms, pe, message) + return bpc +end + +""" +Compute message tensor as product of incoming mts and local state +""" +function updated_message( + bpc::AbstractBeliefPropagationCache, + edge::PartitionEdge; + message_update_function=default_message_update, + message_update_function_kwargs=(;), +) + vertex = src(edge) + incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)]) + state = factors(bpc, vertex) + + return message_update_function( + ITensor[incoming_ms; state]; message_update_function_kwargs... + ) +end + +function update( + alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs... +) + return set_message(bpc, edge, updated_message(bpc, edge; kwargs...)) +end + +""" +Do a sequential update of the message tensors on `edges` +""" +function update( + alg::Algorithm, + bpc::AbstractBeliefPropagationCache, + edges::Vector; + (update_diff!)=nothing, + kwargs..., +) + bpc = copy(bpc) + for e in edges + prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing + bpc = update(alg, bpc, e; kwargs...) + if !isnothing(update_diff!) + update_diff![] += message_diff(message(bpc, e), prev_message) + end + end + return bpc +end + +""" +Do parallel updates between groups of edges of all message tensors +Currently we send the full message tensor data struct to update for each edge_group. But really we only need the +mts relevant to that group. +""" +function update( + alg::Algorithm, + bpc::AbstractBeliefPropagationCache, + edge_groups::Vector{<:Vector{<:PartitionEdge}}; + kwargs..., +) + new_mts = copy(messages(bpc)) + for edges in edge_groups + bpc_t = update(alg, bpc, edges; kwargs...) + for e in edges + new_mts[e] = message(bpc_t, e) + end + end + return set_messages(bpc, new_mts) +end + +""" +More generic interface for update, with default params +""" +function update( + alg::Algorithm, + bpc::AbstractBeliefPropagationCache; + edges=default_edge_sequence(alg, bpc), + maxiter=default_bp_maxiter(alg, bpc), + message_update_kwargs=default_message_update_kwargs(alg, bpc), + tol=nothing, + verbose=false, +) + compute_error = !isnothing(tol) + if isnothing(maxiter) + error("You need to specify a number of iterations for BP!") + end + for i in 1:maxiter + diff = compute_error ? Ref(0.0) : nothing + bpc = update(alg, bpc, edges; (update_diff!)=diff, message_update_kwargs...) + if compute_error && (diff.x / length(edges)) <= tol + if verbose + println("BP converged to desired precision after $i iterations.") + end + break + end + end + return bpc +end + +function update( + bpc::AbstractBeliefPropagationCache; + alg::String=default_message_update_alg(bpc), + kwargs..., +) + return update(Algorithm(alg), bpc; kwargs...) +end diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index d5886ec7..31d43e9b 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -14,57 +14,33 @@ using NamedGraphs.PartitionedGraphs: using SimpleTraits: SimpleTraits, Not, @traitfn using NDTensors: NDTensors -default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e] -default_messages(ptn::PartitionedGraph) = Dictionary() -function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...) - sequence = optimal_contraction_sequence(contract_list) - updated_messages = contract(contract_list; sequence, kwargs...) - message_norm = norm(updated_messages) - if normalize && !iszero(message_norm) - updated_messages /= message_norm - end - return ITensor[updated_messages] -end -@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing -@traitfn function default_bp_maxiter(g::::IsDirected) - return default_bp_maxiter(undirected_graph(underlying_graph(g))) -end -default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) -function default_partitioned_vertices(f::AbstractFormNetwork) - return group(v -> original_state_vertex(f, v), vertices(f)) -end -default_cache_update_kwargs(cache) = (; maxiter=25, tol=1e-8) function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork) return (; partitioned_vertices=default_partitioned_vertices(ψ)) end -#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages -function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) - lhs, rhs = contract(message_a), contract(message_b) - f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs))) - return 1 - f +function default_cache_construction_kwargs(alg::Algorithm"bp", pg::PartitionedGraph) + return (;) end -struct BeliefPropagationCache{PTN,MTS,DM} +struct BeliefPropagationCache{PTN,MTS} <: AbstractBeliefPropagationCache partitioned_tensornetwork::PTN messages::MTS - default_message::DM end #Constructors... -function BeliefPropagationCache( - ptn::PartitionedGraph; messages=default_messages(ptn), default_message=default_message -) - return BeliefPropagationCache(ptn, messages, default_message) +function BeliefPropagationCache(ptn::PartitionedGraph; messages=default_messages(ptn)) + return BeliefPropagationCache(ptn, messages) end -function BeliefPropagationCache(tn, partitioned_vertices; kwargs...) +function BeliefPropagationCache(tn::AbstractITensorNetwork, partitioned_vertices; kwargs...) ptn = PartitionedGraph(tn, partitioned_vertices) return BeliefPropagationCache(ptn; kwargs...) end function BeliefPropagationCache( - tn; partitioned_vertices=default_partitioned_vertices(tn), kwargs... + tn::AbstractITensorNetwork; + partitioned_vertices=default_partitioned_vertices(tn), + kwargs..., ) return BeliefPropagationCache(tn, partitioned_vertices; kwargs...) end @@ -72,204 +48,50 @@ end function cache(alg::Algorithm"bp", tn; kwargs...) return BeliefPropagationCache(tn; kwargs...) end +default_cache_update_kwargs(alg::Algorithm"bp") = (; maxiter=25, tol=1e-8) function partitioned_tensornetwork(bp_cache::BeliefPropagationCache) return bp_cache.partitioned_tensornetwork end -messages(bp_cache::BeliefPropagationCache) = bp_cache.messages -default_message(bp_cache::BeliefPropagationCache) = bp_cache.default_message -function tensornetwork(bp_cache::BeliefPropagationCache) - return unpartitioned_graph(partitioned_tensornetwork(bp_cache)) -end -#Forward from partitioned graph -for f in [ - :(PartitionedGraphs.partitioned_graph), - :(PartitionedGraphs.partitionedge), - :(PartitionedGraphs.partitionvertices), - :(PartitionedGraphs.vertices), - :(PartitionedGraphs.boundary_partitionedges), - :(ITensorMPS.linkinds), -] - @eval begin - function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) - return $f(partitioned_tensornetwork(bp_cache), args...; kwargs...) - end - end -end - -NDTensors.scalartype(bp_cache) = scalartype(tensornetwork(bp_cache)) +messages(bp_cache::BeliefPropagationCache) = bp_cache.messages function default_message(bp_cache::BeliefPropagationCache, edge::PartitionEdge) - return default_message(bp_cache)(scalartype(bp_cache), linkinds(bp_cache, edge)) -end - -function message(bp_cache::BeliefPropagationCache, edge::PartitionEdge) - mts = messages(bp_cache) - return get(() -> default_message(bp_cache, edge), mts, edge) -end -function messages(bp_cache::BeliefPropagationCache, edges; kwargs...) - return map(edge -> message(bp_cache, edge; kwargs...), edges) + return default_message(scalartype(bp_cache), linkinds(bp_cache, edge)) end function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache( - copy(partitioned_tensornetwork(bp_cache)), - copy(messages(bp_cache)), - default_message(bp_cache), + copy(partitioned_tensornetwork(bp_cache)), copy(messages(bp_cache)) ) end -function default_bp_maxiter(bp_cache::BeliefPropagationCache) +default_message_update_alg(bp_cache::BeliefPropagationCache) = "bp" + +function default_bp_maxiter(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) return default_bp_maxiter(partitioned_graph(bp_cache)) end -function default_edge_sequence(bp_cache::BeliefPropagationCache) +function default_edge_sequence(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) return default_edge_sequence(partitioned_tensornetwork(bp_cache)) end - -function set_messages(cache::BeliefPropagationCache, messages) - return BeliefPropagationCache( - partitioned_tensornetwork(cache), messages, default_message(cache) - ) -end - -function environment( - bp_cache::BeliefPropagationCache, - partition_vertices::Vector{<:PartitionVertex}; - ignore_edges=(), -) - bpes = boundary_partitionedges(bp_cache, partition_vertices; dir=:in) - ms = messages(bp_cache, setdiff(bpes, ignore_edges)) - return reduce(vcat, ms; init=ITensor[]) -end - -function environment( - bp_cache::BeliefPropagationCache, partition_vertex::PartitionVertex; kwargs... -) - return environment(bp_cache, [partition_vertex]; kwargs...) -end - -function environment(bp_cache::BeliefPropagationCache, verts::Vector) - partition_verts = partitionvertices(bp_cache, verts) - messages = environment(bp_cache, partition_verts) - central_tensors = factors(bp_cache, setdiff(vertices(bp_cache, partition_verts), verts)) - return vcat(messages, central_tensors) -end - -function factors(bp_cache::BeliefPropagationCache, verts::Vector) - return ITensor[tensornetwork(bp_cache)[v] for v in verts] -end - -function factor(bp_cache::BeliefPropagationCache, vertex::PartitionVertex) - return factors(bp_cache, vertices(bp_cache, vertex)) -end - -""" -Compute message tensor as product of incoming mts and local state -""" -function update_message( - bp_cache::BeliefPropagationCache, - edge::PartitionEdge; - message_update=default_message_update, - message_update_kwargs=(;), +function default_message_update_kwargs( + alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache ) - vertex = src(edge) - messages = environment(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)]) - state = factor(bp_cache, vertex) - - return message_update(ITensor[messages; state]; message_update_kwargs...) + return (;) end -""" -Do a sequential update of the message tensors on `edges` -""" -function update( - bp_cache::BeliefPropagationCache, - edges::Vector{<:PartitionEdge}; - (update_diff!)=nothing, - kwargs..., -) - bp_cache_updated = copy(bp_cache) - mts = messages(bp_cache_updated) - for e in edges - set!(mts, e, update_message(bp_cache_updated, e; kwargs...)) - if !isnothing(update_diff!) - update_diff![] += message_diff(message(bp_cache, e), mts[e]) - end - end - return bp_cache_updated -end +partitions(bpc::BeliefPropagationCache) = partitionvertices(partitioned_tensornetwork(bpc)) +partitionpairs(bpc::BeliefPropagationCache) = partitionedges(partitioned_tensornetwork(bpc)) -""" -Update the message tensor on a single edge -""" -function update(bp_cache::BeliefPropagationCache, edge::PartitionEdge; kwargs...) - return update(bp_cache, [edge]; kwargs...) -end - -""" -Do parallel updates between groups of edges of all message tensors -Currently we send the full message tensor data struct to update for each edge_group. But really we only need the -mts relevant to that group. -""" -function update( - bp_cache::BeliefPropagationCache, - edge_groups::Vector{<:Vector{<:PartitionEdge}}; - kwargs..., -) - new_mts = copy(messages(bp_cache)) - for edges in edge_groups - bp_cache_t = update(bp_cache, edges; kwargs...) - for e in edges - new_mts[e] = message(bp_cache_t, e) - end - end - return set_messages(bp_cache, new_mts) -end - -""" -More generic interface for update, with default params -""" -function update( - bp_cache::BeliefPropagationCache; - edges=default_edge_sequence(bp_cache), - maxiter=default_bp_maxiter(bp_cache), - tol=nothing, - verbose=false, - kwargs..., -) - compute_error = !isnothing(tol) - if isnothing(maxiter) - error("You need to specify a number of iterations for BP!") - end - for i in 1:maxiter - diff = compute_error ? Ref(0.0) : nothing - bp_cache = update(bp_cache, edges; (update_diff!)=diff, kwargs...) - if compute_error && (diff.x / length(edges)) <= tol - if verbose - println("BP converged to desired precision after $i iterations.") - end - break - end - end - return bp_cache -end - -""" -Update the tensornetwork inside the cache -""" -function update_factors(bp_cache::BeliefPropagationCache, factors) - bp_cache = copy(bp_cache) - tn = tensornetwork(bp_cache) - for vertex in eachindex(factors) - # TODO: Add a check that this preserves the graph structure. - setindex_preserve_graph!(tn, factors[vertex], vertex) - end - return bp_cache +function set_messages(cache::BeliefPropagationCache, messages) + return BeliefPropagationCache(partitioned_tensornetwork(cache), messages) end -function update_factor(bp_cache, vertex, factor) - return update_factors(bp_cache, Dictionary([vertex], [factor])) +function environment(bpc::BeliefPropagationCache, verts::Vector; kwargs...) + partition_verts = partitionvertices(bpc, verts) + messages = incoming_messages(bpc, partition_verts; kwargs...) + central_tensors = factors(bpc, setdiff(vertices(bpc, partition_verts), verts)) + return vcat(messages, central_tensors) end function region_scalar( @@ -277,8 +99,8 @@ function region_scalar( pv::PartitionVertex; contract_kwargs=(; sequence="automatic"), ) - incoming_mts = environment(bp_cache, [pv]) - local_state = factor(bp_cache, pv) + incoming_mts = incoming_messages(bp_cache, [pv]) + local_state = factors(bp_cache, pv) return contract(vcat(incoming_mts, local_state); contract_kwargs...)[] end @@ -291,23 +113,3 @@ function region_scalar( vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))); contract_kwargs... )[] end - -function vertex_scalars( - bp_cache::BeliefPropagationCache, - pvs=partitionvertices(partitioned_tensornetwork(bp_cache)); - kwargs..., -) - return map(pv -> region_scalar(bp_cache, pv; kwargs...), pvs) -end - -function edge_scalars( - bp_cache::BeliefPropagationCache, - pes=partitionedges(partitioned_tensornetwork(bp_cache)); - kwargs..., -) - return map(pe -> region_scalar(bp_cache, pe; kwargs...), pes) -end - -function scalar_factors_quotient(bp_cache::BeliefPropagationCache) - return vertex_scalars(bp_cache), edge_scalars(bp_cache) -end diff --git a/src/contract.jl b/src/contract.jl index 4adb0e10..70036566 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -30,7 +30,7 @@ function NDTensors.contract( return contract_approx(alg, tn, output_structure; kwargs...) end -function ITensors.scalar(alg::Algorithm, tn::AbstractITensorNetwork; kwargs...) +function ITensors.scalar(alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...) return contract(alg, tn; kwargs...)[] end @@ -54,7 +54,7 @@ function logscalar( (cache!)=nothing, cache_construction_kwargs=default_cache_construction_kwargs(alg, tn), update_cache=isnothing(cache!), - cache_update_kwargs=default_cache_update_kwargs(cache!), + cache_update_kwargs=default_cache_update_kwargs(alg), ) if isnothing(cache!) cache! = Ref(cache(alg, tn; cache_construction_kwargs...)) @@ -77,6 +77,6 @@ function logscalar( return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end -function ITensors.scalar(alg::Algorithm"bp", tn::AbstractITensorNetwork; kwargs...) +function ITensors.scalar(alg::Algorithm, tn::AbstractITensorNetwork; kwargs...) return exp(logscalar(alg, tn; kwargs...)) end diff --git a/src/environment.jl b/src/environment.jl index f3c424c0..cec13f21 100644 --- a/src/environment.jl +++ b/src/environment.jl @@ -1,7 +1,7 @@ using ITensors: contract using NamedGraphs.PartitionedGraphs: PartitionedGraph -default_environment_algorithm() = "exact" +default_environment_algorithm() = "bp" function environment( tn::AbstractITensorNetwork, @@ -19,15 +19,16 @@ function environment( end function environment( - ::Algorithm"bp", + alg::Algorithm, ptn::PartitionedGraph, vertices::Vector; (cache!)=nothing, update_cache=isnothing(cache!), - cache_update_kwargs=default_cache_update_kwargs(cache!), + cache_construction_kwargs=default_cache_construction_kwargs(alg, ptn), + cache_update_kwargs=default_cache_update_kwargs(alg), ) if isnothing(cache!) - cache! = Ref(BeliefPropagationCache(ptn)) + cache! = Ref(cache(alg, ptn; cache_construction_kwargs...)) end if update_cache @@ -38,7 +39,7 @@ function environment( end function environment( - alg::Algorithm"bp", + alg::Algorithm, tn::AbstractITensorNetwork, vertices::Vector; partitioned_vertices=default_partitioned_vertices(tn), diff --git a/src/expect.jl b/src/expect.jl index e1d46a9f..69ce42ae 100644 --- a/src/expect.jl +++ b/src/expect.jl @@ -24,14 +24,13 @@ function ITensorMPS.expect( ops; (cache!)=nothing, update_cache=isnothing(cache!), - cache_update_kwargs=default_cache_update_kwargs(cache!), - cache_construction_function=tn -> - cache(alg, tn; default_cache_construction_kwargs(alg, tn)...), + cache_update_kwargs=default_cache_update_kwargs(alg), + cache_construction_kwargs=default_cache_construction_kwargs(alg, inner_network(ψ, ψ)), kwargs..., ) ψIψ = inner_network(ψ, ψ) if isnothing(cache!) - cache! = Ref(cache_construction_function(ψIψ)) + cache! = Ref(cache(alg, ψIψ; cache_construction_kwargs...)) end if update_cache diff --git a/src/gauging.jl b/src/gauging.jl index 2ad9d0f4..49384ca6 100644 --- a/src/gauging.jl +++ b/src/gauging.jl @@ -128,7 +128,7 @@ function VidalITensorNetwork( ψ::ITensorNetwork; (cache!)=nothing, update_cache=isnothing(cache!), - cache_update_kwargs=default_cache_update_kwargs(cache!), + cache_update_kwargs=default_cache_update_kwargs(Algorithm("bp")), kwargs..., ) if isnothing(cache!) diff --git a/src/inner.jl b/src/inner.jl index 43486703..6f1cdc86 100644 --- a/src/inner.jl +++ b/src/inner.jl @@ -90,7 +90,7 @@ function ITensorMPS.loginner( end function ITensorMPS.loginner( - alg::Algorithm"bp", + alg::Algorithm, ϕ::AbstractITensorNetwork, ψ::AbstractITensorNetwork; dual_link_index_map=sim, @@ -101,7 +101,7 @@ function ITensorMPS.loginner( end function ITensorMPS.loginner( - alg::Algorithm"bp", + alg::Algorithm, ϕ::AbstractITensorNetwork, A::AbstractITensorNetwork, ψ::AbstractITensorNetwork; @@ -113,7 +113,7 @@ function ITensorMPS.loginner( end function ITensors.inner( - alg::Algorithm"bp", + alg::Algorithm, ϕ::AbstractITensorNetwork, ψ::AbstractITensorNetwork; dual_link_index_map=sim, @@ -124,7 +124,7 @@ function ITensors.inner( end function ITensors.inner( - alg::Algorithm"bp", + alg::Algorithm, ϕ::AbstractITensorNetwork, A::AbstractITensorNetwork, ψ::AbstractITensorNetwork; diff --git a/test/test_apply.jl b/test/test_apply.jl index 18d4f590..0bbc580a 100644 --- a/test/test_apply.jl +++ b/test/test_apply.jl @@ -30,7 +30,7 @@ using Test: @test, @testset #Simple Belief Propagation Grouping bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bp_cache = update(bp_cache; maxiter=20) - envsSBP = environment(bp_cache, PartitionVertex.([v1, v2])) + envsSBP = environment(bp_cache, [(v1, "bra"), (v1, "ket"), (v2, "bra"), (v2, "ket")]) ψv = VidalITensorNetwork(ψ) #This grouping will correspond to calculating the environments exactly (each column of the grid is a partition) bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1][1], vertices(ψψ))) diff --git a/test/test_belief_propagation.jl b/test/test_belief_propagation.jl index a151e4d1..040d701f 100644 --- a/test/test_belief_propagation.jl +++ b/test/test_belief_propagation.jl @@ -23,7 +23,7 @@ using ITensorNetworks: tensornetwork, update, update_factor, - update_message, + updated_message, message_diff using ITensors: ITensors, ITensor, combiner, dag, inds, inner, op, prime, random_itensor using ITensorNetworks.ModelNetworks: ModelNetworks @@ -51,7 +51,7 @@ using Test: @test, @testset bpc = update(bpc; maxiter=25, tol=eps(real(elt))) #Test messages are converged for pe in partitionedges(partitioned_tensornetwork(bpc)) - @test message_diff(update_message(bpc, pe), message(bpc, pe)) < 10 * eps(real(elt)) + @test message_diff(updated_message(bpc, pe), message(bpc, pe)) < 10 * eps(real(elt)) @test eltype(only(message(bpc, pe))) == elt end #Test updating the underlying tensornetwork in the cache diff --git a/test/test_expect.jl b/test/test_expect.jl index 75ed8504..c654138e 100644 --- a/test/test_expect.jl +++ b/test/test_expect.jl @@ -29,11 +29,13 @@ using Test: @test, @testset s = siteinds("S=1/2", g) rng = StableRNG(1234) ψ = random_tensornetwork(rng, s; link_space=χ) - cache_construction_function = - f -> BeliefPropagationCache( - f; partitioned_vertices=group(v -> (original_state_vertex(f, v)[1]), vertices(f)) - ) - sz_bp = expect(ψ, "Sz"; alg="bp", cache_construction_function) + quadratic_form_vertices = reduce( + vcat, [[(v, "ket"), (v, "bra"), (v, "operator")] for v in vertices(ψ)] + ) + cache_construction_kwargs = (; + partitioned_vertices=group(v -> first(first(v)), quadratic_form_vertices) + ) + sz_bp = expect(ψ, "Sz"; alg="bp", cache_construction_kwargs) sz_exact = expect(ψ, "Sz"; alg="exact") @test sz_bp ≈ sz_exact diff --git a/test/test_forms.jl b/test/test_forms.jl index a58822e5..7e6ada8b 100644 --- a/test/test_forms.jl +++ b/test/test_forms.jl @@ -62,7 +62,7 @@ using Test: @test, @testset @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) @test underlying_graph(operator_network(qf)) == underlying_graph(A) - ∂qf_∂v = only(environment(qf, state_vertices(qf, [v]))) + ∂qf_∂v = only(environment(qf, state_vertices(qf, [v]); alg="exact")) @test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) ≈ contract(qf) ∂qf_∂v_bp = environment(qf, state_vertices(qf, [v]); alg="bp", update_cache=false)