-
Notifications
You must be signed in to change notification settings - Fork 23
AbstractBeliefPropagationCache #217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
mtfishman
merged 58 commits into
ITensor:main
from
JoeyT1994:AbstractBeliefPropagationCache
Mar 18, 2025
Merged
Changes from 56 commits
Commits
Show all changes
58 commits
Select commit
Hold shift + click to select a range
c845947
Blah
JoeyT1994 90c7251
Merge remote-tracking branch 'origin/main'
JoeyT1994 86f3087
Merge remote-tracking branch 'upstream/main'
JoeyT1994 6ff0cd5
Bug fix in current ortho. Change test
JoeyT1994 34e8e5e
Merge remote-tracking branch 'upstream/main'
JoeyT1994 d096722
Fix bug
JoeyT1994 70a3f7e
Merge remote-tracking branch 'upstream/main'
JoeyT1994 921810d
Save First Run
JoeyT1994 c1c1f94
Preliminary ideas
JoeyT1994 c7ab747
Save
JoeyT1994 6f9dc32
First commit
JoeyT1994 08f574a
Sequence sorting
JoeyT1994 63b18ee
Test change
JoeyT1994 008b981
Stuff
JoeyT1994 4febcc3
Gauging
JoeyT1994 50c53d7
Workinggit add examples/test_boundarymps.jl
JoeyT1994 9b9b05d
Missing edges
JoeyT1994 a8c04d3
Simplify
JoeyT1994 d8ce9cb
Code clean
JoeyT1994 d095c6f
Improved nomenclature
JoeyT1994 6e8d7b8
Adding biorthogonalization feature
JoeyT1994 2ca6404
Biorthogonal algorithm
JoeyT1994 05ae49a
File Structure
JoeyT1994 9c5f6c4
Utils
JoeyT1994 12722da
Add test
JoeyT1994 f7c6beb
More tests
JoeyT1994 7d4465b
Testing
JoeyT1994 1b23fbb
Working Commit
JoeyT1994 338a41f
BoundaryMPS
JoeyT1994 62c87b4
Merge branch 'BoundaryMPS' of github.com:JoeyT1994/ITensorNetworks.jl…
JoeyT1994 9d64029
Revert BP Cache Code
JoeyT1994 dc46339
Rename kwarg
JoeyT1994 dd22863
examples/test_boundarymps.jl
JoeyT1994 807739f
Updated tests
JoeyT1994 ee078e5
Formatting
JoeyT1994 a028422
AbstractCache
JoeyT1994 81dcea4
Revert "AbstractCache"
JoeyT1994 5dc9677
Abstract Cahce
JoeyT1994 611bf18
Working Abstract Cache
JoeyT1994 6c7a2a8
Simplify initialization of messges
JoeyT1994 531c5af
Rm redundant file
JoeyT1994 93d905f
Merge branch 'AbstractCache' into BoundaryMPS
JoeyT1994 b11c2f8
Bug fix in test_apply
JoeyT1994 99b4929
Generic update interface for boundary mps and simple bp
JoeyT1994 19630aa
Unify update function
JoeyT1994 35fe94f
Generic naming support
JoeyT1994 4ae3b53
Updated expect, inner and environment interfaces
JoeyT1994 df7a4c5
Fix failing tests
JoeyT1994 e9e0336
Two site working
JoeyT1994 e98a056
Cleanup
JoeyT1994 5b9266e
Remove @show
JoeyT1994 0228aa1
Fix Bug in Expect
JoeyT1994 1fcf947
Fix expect test
JoeyT1994 8c2cc7e
Removed explicit BoundaryMPSCache
JoeyT1994 8a6d33a
Merge remote-tracking branch 'upstream/main' into BoundaryMPS
JoeyT1994 e6e86d6
Remove ref to BoundaryMPSFiles
JoeyT1994 33afe1c
Formatting
JoeyT1994 b31f139
Version Bump
JoeyT1994 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,292 @@ | ||
| using Graphs: IsDirected | ||
| using SplitApplyCombine: group | ||
| using LinearAlgebra: diag, dot | ||
| using ITensors: dir | ||
| using ITensorMPS: ITensorMPS | ||
| using NamedGraphs.PartitionedGraphs: | ||
| PartitionedGraphs, | ||
| PartitionedGraph, | ||
| PartitionVertex, | ||
| boundary_partitionedges, | ||
| partitionvertices, | ||
| partitionedges, | ||
| unpartitioned_graph | ||
| using SimpleTraits: SimpleTraits, Not, @traitfn | ||
| using NDTensors: NDTensors | ||
|
|
||
| abstract type AbstractBeliefPropagationCache end | ||
|
|
||
| function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...) | ||
| sequence = optimal_contraction_sequence(contract_list) | ||
| updated_messages = contract(contract_list; sequence, kwargs...) | ||
| message_norm = norm(updated_messages) | ||
| if normalize && !iszero(message_norm) | ||
| updated_messages /= message_norm | ||
| end | ||
| return ITensor[updated_messages] | ||
| end | ||
|
|
||
| #TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages | ||
| function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) | ||
| lhs, rhs = contract(message_a), contract(message_b) | ||
| f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs))) | ||
| return 1 - f | ||
| end | ||
|
|
||
| default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e] | ||
| default_messages(ptn::PartitionedGraph) = Dictionary() | ||
| @traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing | ||
| @traitfn function default_bp_maxiter(g::::IsDirected) | ||
| return default_bp_maxiter(undirected_graph(underlying_graph(g))) | ||
| end | ||
| default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) | ||
| function default_partitioned_vertices(f::AbstractFormNetwork) | ||
| return group(v -> original_state_vertex(f, v), vertices(f)) | ||
| end | ||
|
|
||
| partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
| messages(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
| function default_message( | ||
| bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs... | ||
| ) | ||
| return not_implemented() | ||
| end | ||
| default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
| Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
| default_bp_maxiter(alg::Algorithm, bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
| function default_edge_sequence(alg::Algorithm, bpc::AbstractBeliefPropagationCache) | ||
| return not_implemented() | ||
| end | ||
| function default_message_update_kwargs(alg::Algorithm, bpc::AbstractBeliefPropagationCache) | ||
| return not_implemented() | ||
| end | ||
| function environment(bpc::AbstractBeliefPropagationCache, verts::Vector; kwargs...) | ||
| return not_implemented() | ||
| end | ||
| function region_scalar(bpc::AbstractBeliefPropagationCache, pv::PartitionVertex; kwargs...) | ||
| return not_implemented() | ||
| end | ||
| function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; kwargs...) | ||
| return not_implemented() | ||
| end | ||
| partitions(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
| partitionpairs(bpc::AbstractBeliefPropagationCache) = not_implemented() | ||
|
|
||
| function default_edge_sequence( | ||
| bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc) | ||
| ) | ||
| return default_edge_sequence(Algorithm(alg), bpc) | ||
| end | ||
| function default_bp_maxiter( | ||
| bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc) | ||
| ) | ||
| return default_bp_maxiter(Algorithm(alg), bpc) | ||
| end | ||
| function default_message_update_kwargs( | ||
| bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc) | ||
| ) | ||
| return default_message_update_kwargs(Algorithm(alg), bpc) | ||
| end | ||
|
|
||
| function tensornetwork(bpc::AbstractBeliefPropagationCache) | ||
| return unpartitioned_graph(partitioned_tensornetwork(bpc)) | ||
| end | ||
|
|
||
| function factors(bpc::AbstractBeliefPropagationCache, verts::Vector) | ||
| return ITensor[tensornetwork(bpc)[v] for v in verts] | ||
| end | ||
|
|
||
| function factors( | ||
| bpc::AbstractBeliefPropagationCache, partition_verts::Vector{<:PartitionVertex} | ||
| ) | ||
| return factors(bpc, vertices(bpc, partition_verts)) | ||
| end | ||
|
|
||
| function factors(bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex) | ||
| return factors(bpc, [partition_vertex]) | ||
| end | ||
|
|
||
| function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc); kwargs...) | ||
| return map(pv -> region_scalar(bpc, pv; kwargs...), pvs) | ||
| end | ||
|
|
||
| function edge_scalars( | ||
| bpc::AbstractBeliefPropagationCache, pes=partitionpairs(bpc); kwargs... | ||
| ) | ||
| return map(pe -> region_scalar(bpc, pe; kwargs...), pes) | ||
| end | ||
|
|
||
| function scalar_factors_quotient(bpc::AbstractBeliefPropagationCache) | ||
| return vertex_scalars(bpc), edge_scalars(bpc) | ||
| end | ||
|
|
||
| function incoming_messages( | ||
| bpc::AbstractBeliefPropagationCache, | ||
| partition_vertices::Vector{<:PartitionVertex}; | ||
| ignore_edges=(), | ||
| ) | ||
| bpes = boundary_partitionedges(bpc, partition_vertices; dir=:in) | ||
| ms = messages(bpc, setdiff(bpes, ignore_edges)) | ||
| return reduce(vcat, ms; init=ITensor[]) | ||
| end | ||
|
|
||
| function incoming_messages( | ||
| bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex; kwargs... | ||
| ) | ||
| return incoming_messages(bpc, [partition_vertex]; kwargs...) | ||
| end | ||
|
|
||
| #Forward from partitioned graph | ||
| for f in [ | ||
| :(PartitionedGraphs.partitioned_graph), | ||
| :(PartitionedGraphs.partitionedge), | ||
| :(PartitionedGraphs.partitionvertices), | ||
| :(PartitionedGraphs.vertices), | ||
| :(PartitionedGraphs.boundary_partitionedges), | ||
| :(ITensorMPS.linkinds), | ||
| ] | ||
| @eval begin | ||
| function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...) | ||
| return $f(partitioned_tensornetwork(bpc), args...; kwargs...) | ||
| end | ||
| end | ||
| end | ||
|
|
||
| NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc)) | ||
|
|
||
| """ | ||
| Update the tensornetwork inside the cache | ||
| """ | ||
| function update_factors(bpc::AbstractBeliefPropagationCache, factors) | ||
| bpc = copy(bpc) | ||
| tn = tensornetwork(bpc) | ||
| for vertex in eachindex(factors) | ||
| # TODO: Add a check that this preserves the graph structure. | ||
| setindex_preserve_graph!(tn, factors[vertex], vertex) | ||
| end | ||
| return bpc | ||
| end | ||
|
|
||
| function update_factor(bpc, vertex, factor) | ||
| return update_factors(bpc, Dictionary([vertex], [factor])) | ||
| end | ||
|
|
||
| function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...) | ||
| mts = messages(bpc) | ||
| return get(() -> default_message(bpc, edge; kwargs...), mts, edge) | ||
| end | ||
| function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...) | ||
| return map(edge -> message(bpc, edge; kwargs...), edges) | ||
| end | ||
| function set_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message) | ||
| bpc = copy(bpc) | ||
| ms = messages(bpc) | ||
| set!(ms, pe, message) | ||
| return bpc | ||
| end | ||
|
|
||
| """ | ||
| Compute message tensor as product of incoming mts and local state | ||
| """ | ||
| function updated_message( | ||
| bpc::AbstractBeliefPropagationCache, | ||
| edge::PartitionEdge; | ||
| message_update_function=default_message_update, | ||
| message_update_function_kwargs=(;), | ||
| ) | ||
| vertex = src(edge) | ||
| incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)]) | ||
| state = factors(bpc, vertex) | ||
|
|
||
| return message_update_function( | ||
| ITensor[incoming_ms; state]; message_update_function_kwargs... | ||
| ) | ||
| end | ||
|
|
||
| function update( | ||
| alg::Algorithm"simplebp", | ||
| bpc::AbstractBeliefPropagationCache, | ||
| edge::PartitionEdge; | ||
| kwargs..., | ||
| ) | ||
| return set_message(bpc, edge, updated_message(bpc, edge; kwargs...)) | ||
| end | ||
|
|
||
| """ | ||
| Do a sequential update of the message tensors on `edges` | ||
| """ | ||
| function update( | ||
| alg::Algorithm, | ||
| bpc::AbstractBeliefPropagationCache, | ||
| edges::Vector; | ||
| (update_diff!)=nothing, | ||
| kwargs..., | ||
| ) | ||
| bpc = copy(bpc) | ||
| for e in edges | ||
| prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing | ||
| bpc = update(alg, bpc, e; kwargs...) | ||
| if !isnothing(update_diff!) | ||
| update_diff![] += message_diff(message(bpc, e), prev_message) | ||
| end | ||
| end | ||
| return bpc | ||
| end | ||
|
|
||
| """ | ||
| Do parallel updates between groups of edges of all message tensors | ||
| Currently we send the full message tensor data struct to update for each edge_group. But really we only need the | ||
| mts relevant to that group. | ||
| """ | ||
| function update( | ||
| alg::Algorithm, | ||
| bpc::AbstractBeliefPropagationCache, | ||
| edge_groups::Vector{<:Vector{<:PartitionEdge}}; | ||
| kwargs..., | ||
| ) | ||
| new_mts = copy(messages(bpc)) | ||
| for edges in edge_groups | ||
| bpc_t = update(alg, bpc, edges; kwargs...) | ||
| for e in edges | ||
| new_mts[e] = message(bpc_t, e) | ||
| end | ||
| end | ||
| return set_messages(bpc, new_mts) | ||
| end | ||
|
|
||
| """ | ||
| More generic interface for update, with default params | ||
| """ | ||
| function update( | ||
| alg::Algorithm, | ||
| bpc::AbstractBeliefPropagationCache; | ||
| edges=default_edge_sequence(alg, bpc), | ||
| maxiter=default_bp_maxiter(alg, bpc), | ||
| message_update_kwargs=default_message_update_kwargs(alg, bpc), | ||
| tol=nothing, | ||
| verbose=false, | ||
| ) | ||
| compute_error = !isnothing(tol) | ||
| if isnothing(maxiter) | ||
| error("You need to specify a number of iterations for BP!") | ||
| end | ||
| for i in 1:maxiter | ||
| diff = compute_error ? Ref(0.0) : nothing | ||
| bpc = update(alg, bpc, edges; (update_diff!)=diff, message_update_kwargs...) | ||
| if compute_error && (diff.x / length(edges)) <= tol | ||
| if verbose | ||
| println("BP converged to desired precision after $i iterations.") | ||
| end | ||
| break | ||
| end | ||
| end | ||
| return bpc | ||
| end | ||
|
|
||
| function update( | ||
| bpc::AbstractBeliefPropagationCache; | ||
| alg::String=default_message_update_alg(bpc), | ||
| kwargs..., | ||
| ) | ||
| return update(Algorithm(alg), bpc; kwargs...) | ||
| end | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.