Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
c845947
Blah
JoeyT1994 Sep 13, 2024
90c7251
Merge remote-tracking branch 'origin/main'
JoeyT1994 Oct 17, 2024
86f3087
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Oct 17, 2024
6ff0cd5
Bug fix in current ortho. Change test
JoeyT1994 Oct 17, 2024
34e8e5e
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Nov 22, 2024
d096722
Fix bug
JoeyT1994 Nov 26, 2024
70a3f7e
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Dec 5, 2024
9d64fe8
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Mar 19, 2025
9d6c1bc
File removed
JoeyT1994 Mar 19, 2025
ae17245
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Mar 23, 2025
b648353
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Apr 1, 2025
4e7d189
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Apr 3, 2025
83c92b0
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Apr 7, 2025
6f024ee
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Apr 11, 2025
1106403
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Apr 17, 2025
5641d32
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Apr 30, 2025
59ef115
Merge remote-tracking branch 'upstream/main'
JoeyT1994 May 9, 2025
72bea86
Merge remote-tracking branch 'upstream/main'
JoeyT1994 May 13, 2025
bc14b33
Merge remote-tracking branch 'upstream/main'
JoeyT1994 May 20, 2025
fdd47c8
Merge remote-tracking branch 'upstream/main'
JoeyT1994 May 24, 2025
d8d3e52
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Jul 21, 2025
79b3aa2
Fix bug in delete_messages. Parameter restrictions in apply.jl
JoeyT1994 Jul 21, 2025
795eb63
Adapt for BPC and custom device contractions
JoeyT1994 Aug 15, 2025
2ebba6d
Fixes
JoeyT1994 Aug 15, 2025
ed9f1dc
Improve interface
JoeyT1994 Aug 15, 2025
9e4e3eb
Better defaults
JoeyT1994 Aug 17, 2025
f0b83c1
Restore Project.toml
JoeyT1994 Aug 17, 2025
756bae0
Merge remote-tracking branch 'upstream/main' into BP_GPU_and_Improvem…
JoeyT1994 Aug 17, 2025
b2d2784
Improvements
JoeyT1994 Aug 18, 2025
60d7bc8
Typo fix
JoeyT1994 Aug 18, 2025
1432a4d
Use f for variable name in map
JoeyT1994 Aug 18, 2025
6222386
set kwargs on algorithms
JoeyT1994 Aug 18, 2025
5b4a8fa
No default maxiter
JoeyT1994 Aug 18, 2025
813b436
Fix tests
JoeyT1994 Aug 18, 2025
15a87a5
Improvements
JoeyT1994 Aug 19, 2025
389f148
Simplify interface, fix tests
JoeyT1994 Aug 19, 2025
51184bf
Add default for updated_message
JoeyT1994 Aug 19, 2025
d962403
Update src/caches/abstractbeliefpropagationcache.jl
JoeyT1994 Aug 19, 2025
8850117
Update src/caches/abstractbeliefpropagationcache.jl
JoeyT1994 Aug 19, 2025
fc6572e
improvements
JoeyT1994 Aug 19, 2025
ea88e55
Merge branch 'BP_GPU_and_Improvements' of github.com:JoeyT1994/ITenso…
JoeyT1994 Aug 19, 2025
27ec78f
No kwargs being passed that arent needed
JoeyT1994 Aug 19, 2025
5ec0710
update_iteration
JoeyT1994 Aug 19, 2025
7b2c3d9
Update src/caches/abstractbeliefpropagationcache.jl
JoeyT1994 Aug 20, 2025
d58c3f4
Update docs and examples .toml
JoeyT1994 Aug 20, 2025
d918da1
Merge branch 'BP_GPU_and_Improvements' of github.com:JoeyT1994/ITenso…
JoeyT1994 Aug 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "ITensorNetworks"
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
version = "0.13.17"
version = "0.14.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
Expand Down Expand Up @@ -33,15 +34,13 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"

[extensions]
ITensorNetworksAdaptExt = "Adapt"
ITensorNetworksEinExprsExt = "EinExprs"
ITensorNetworksGraphsFlowsExt = "GraphsFlows"
ITensorNetworksOMEinsumContractionOrdersExt = "OMEinsumContractionOrders"
Expand Down Expand Up @@ -82,7 +81,6 @@ TupleTools = "1.4"
julia = "1.10"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"

[compat]
Documenter = "1.10.0"
ITensorNetworks = "0.13.0"
ITensorNetworks = "0.14.0"
Literate = "2.20.1"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7"

[compat]
ITensorNetworks = "0.13.2"
ITensorNetworks = "0.14.0"
14 changes: 0 additions & 14 deletions ext/ITensorNetworksAdaptExt/ITensorNetworksAdaptExt.jl

This file was deleted.

16 changes: 16 additions & 0 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Adapt: Adapt, adapt, adapt_structure
using DataGraphs:
DataGraphs, edge_data, underlying_graph, underlying_graph_type, vertex_data
using Dictionaries: Dictionary
Expand Down Expand Up @@ -86,6 +87,10 @@ function DataGraphs.underlying_graph_type(G::Type{<:AbstractITensorNetwork})
return underlying_graph_type(data_graph_type(G))
end

function ITensors.datatype(tn::AbstractITensorNetwork)
return mapreduce(v -> datatype(tn[v]), promote_type, vertices(tn))
end

# AbstractDataGraphs overloads
function DataGraphs.vertex_data(graph::AbstractITensorNetwork, args...)
return vertex_data(data_graph(graph), args...)
Expand All @@ -102,6 +107,17 @@ function NamedGraphs.ordered_vertices(tn::AbstractITensorNetwork)
return NamedGraphs.ordered_vertices(underlying_graph(tn))
end

function Adapt.adapt_structure(to, tn::AbstractITensorNetwork)
# TODO: Define and use:
#
# @preserve_graph map_vertex_data(adapt(to), tn)
#
# or just:
#
# @preserve_graph map(adapt(to), tn)
return map_vertex_data_preserve_graph(adapt(to), tn)
end

#
# Iteration
#
Expand Down
159 changes: 89 additions & 70 deletions src/caches/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Adapt: Adapt, adapt, adapt_structure
using Graphs: Graphs, IsDirected
using SplitApplyCombine: group
using LinearAlgebra: diag, dot
Expand All @@ -24,24 +25,20 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
end
data_graph(bpc::AbstractBeliefPropagationCache) = data_graph(tensornetwork(bpc))

function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
sequence = contraction_sequence(contract_list; alg="optimal")
updated_messages = contract(contract_list; sequence, kwargs...)
message_norm = norm(updated_messages)
if normalize && !iszero(message_norm)
updated_messages /= message_norm
end
return ITensor[updated_messages]
end

#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
lhs, rhs = contract(message_a), contract(message_b)
f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs)))
return 1 - f
end

default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e]
function default_message(datatype::Type{<:AbstractArray}, inds_e)
return [adapt(datatype, denseblocks(delta(i))) for i in inds_e]
end

function default_message(elt::Type{<:Number}, inds_e)
return default_message(Vector{elt}, inds_e)
end
default_messages(ptn::PartitionedGraph) = Dictionary()
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
@traitfn function default_bp_maxiter(g::::IsDirected)
Expand All @@ -59,15 +56,13 @@ function default_message(
)
return not_implemented()
end
default_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_bp_maxiter(alg::Algorithm, bpc::AbstractBeliefPropagationCache) = not_implemented()
function default_edge_sequence(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
return not_implemented()
end
function default_message_update_kwargs(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
return not_implemented()
end
function environment(bpc::AbstractBeliefPropagationCache, verts::Vector; kwargs...)
return not_implemented()
end
Expand All @@ -80,21 +75,8 @@ end
partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
PartitionedGraphs.partitionedges(bpc::AbstractBeliefPropagationCache) = not_implemented()

function default_edge_sequence(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_edge_sequence(Algorithm(alg), bpc)
end
function default_bp_maxiter(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_bp_maxiter(Algorithm(alg), bpc)
end
function default_message_update_kwargs(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_message_update_kwargs(Algorithm(alg), bpc)
end
default_bp_edge_sequence(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_bp_maxiter(bpc::AbstractBeliefPropagationCache) = not_implemented()

function tensornetwork(bpc::AbstractBeliefPropagationCache)
return unpartitioned_graph(partitioned_tensornetwork(bpc))
Expand Down Expand Up @@ -144,6 +126,36 @@ function incoming_messages(
return incoming_messages(bpc, [partition_vertex]; kwargs...)
end

#Adapt interface for changing device
function map_messages(
f, bpc::AbstractBeliefPropagationCache, pes=collect(keys(messages(bpc)))
)
bpc = copy(bpc)
for pe in pes
set_message!(bpc, pe, f.(message(bpc, pe)))
end
return bpc
end
function map_factors(f, bpc::AbstractBeliefPropagationCache, vs=vertices(bpc))
bpc = copy(bpc)
for v in vs
@preserve_graph bpc[v] = f(bpc[v])
end
return bpc
end
function adapt_messages(to, bpc::AbstractBeliefPropagationCache, args...)
return map_messages(adapt(to), bpc, args...)
end
function adapt_factors(to, bpc::AbstractBeliefPropagationCache, args...)
return map_factors(adapt(to), bpc, args...)
end

function Adapt.adapt_structure(to, bpc::AbstractBeliefPropagationCache)
bpc = adapt_messages(to, bpc)
bpc = adapt_factors(to, bpc)
return bpc
end

#Forward from partitioned graph
for f in [
:(PartitionedGraphs.partitioned_graph),
Expand Down Expand Up @@ -234,44 +246,63 @@ function delete_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
return delete_messages(bpc, [pe])
end

"""
Compute message tensor as product of incoming mts and local state
"""
function updated_message(
bpc::AbstractBeliefPropagationCache,
edge::PartitionEdge;
message_update_function=default_message_update,
message_update_function_kwargs=(;),
alg::Algorithm"contract", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
)
vertex = src(edge)
incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
state = factors(bpc, vertex)
contract_list = ITensor[incoming_ms; state]
sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg)
updated_messages = contract(contract_list; sequence)
message_norm = norm(updated_messages)
if alg.kwargs.normalize && !iszero(message_norm)
updated_messages /= message_norm
end
return ITensor[updated_messages]
end

return message_update_function(
ITensor[incoming_ms; state]; message_update_function_kwargs...
function updated_message(
alg::Algorithm"adapt_update", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
)
incoming_pes = setdiff(
boundary_partitionedges(bpc, [src(edge)]; dir=:in), [reverse(edge)]
)
adapted_bpc = adapt_messages(alg.kwargs.adapt, bpc, incoming_pes)
adapted_bpc = adapt_factors(alg.kwargs.adapt, bpc, vertices(bpc, src(edge)))
updated_messages = updated_message(alg.kwargs.alg, adapted_bpc, edge)
dtype = mapreduce(datatype, promote_type, message(bpc, edge))
return map(adapt(dtype), updated_messages)
end

function updated_message(
bpc::AbstractBeliefPropagationCache,
edge::PartitionEdge;
alg=default_message_update_alg(bpc),
kwargs...,
)
return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bpc, edge)
end

function update(
alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...
function update_message(
message_update_alg::Algorithm, bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
)
return set_message(bpc, edge, updated_message(bpc, edge; kwargs...))
return set_message(bpc, edge, updated_message(message_update_alg, bpc, edge))
end

"""
Do a sequential update of the message tensors on `edges`
"""
function update(
alg::Algorithm,
function update_iteration(
alg::Algorithm"bp",
bpc::AbstractBeliefPropagationCache,
edges::Vector;
(update_diff!)=nothing,
kwargs...,
)
bpc = copy(bpc)
for e in edges
prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing
bpc = update(alg, bpc, e; kwargs...)
bpc = update_message(alg.kwargs.message_update_alg, bpc, e)
if !isnothing(update_diff!)
update_diff![] += message_diff(message(bpc, e), prev_message)
end
Expand All @@ -284,15 +315,15 @@ Do parallel updates between groups of edges of all message tensors
Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
mts relevant to that group.
"""
function update(
alg::Algorithm,
function update_iteration(
alg::Algorithm"bp",
bpc::AbstractBeliefPropagationCache,
edge_groups::Vector{<:Vector{<:PartitionEdge}};
kwargs...,
(update_diff!)=nothing,
)
new_mts = empty(messages(bpc))
for edges in edge_groups
bpc_t = update(alg, bpc, edges; kwargs...)
bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!))
for e in edges
set!(new_mts, e, message(bpc_t, e))
end
Expand All @@ -303,24 +334,16 @@ end
"""
More generic interface for update, with default params
"""
function update(
alg::Algorithm,
bpc::AbstractBeliefPropagationCache;
edges=default_edge_sequence(alg, bpc),
maxiter=default_bp_maxiter(alg, bpc),
message_update_kwargs=default_message_update_kwargs(alg, bpc),
tol=nothing,
verbose=false,
)
compute_error = !isnothing(tol)
if isnothing(maxiter)
function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache)
compute_error = !isnothing(alg.kwargs.tol)
if isnothing(alg.kwargs.maxiter)
error("You need to specify a number of iterations for BP!")
end
for i in 1:maxiter
for i in 1:alg.kwargs.maxiter
diff = compute_error ? Ref(0.0) : nothing
bpc = update(alg, bpc, edges; (update_diff!)=diff, message_update_kwargs...)
if compute_error && (diff.x / length(edges)) <= tol
if verbose
bpc = update_iteration(alg, bpc, alg.kwargs.edge_sequence; (update_diff!)=diff)
if compute_error && (diff.x / length(alg.kwargs.edge_sequence)) <= alg.kwargs.tol
if alg.kwargs.verbose
println("BP converged to desired precision after $i iterations.")
end
break
Expand All @@ -329,12 +352,8 @@ function update(
return bpc
end

function update(
bpc::AbstractBeliefPropagationCache;
alg::String=default_message_update_alg(bpc),
kwargs...,
)
return update(Algorithm(alg), bpc; kwargs...)
function update(bpc::AbstractBeliefPropagationCache; alg=default_update_alg(bpc), kwargs...)
return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc)
end

function rescale_messages(
Expand Down
Loading
Loading