Merged

46 commits
c845947: Blah (JoeyT1994, Sep 13, 2024)
90c7251: Merge remote-tracking branch 'origin/main' (JoeyT1994, Oct 17, 2024)
86f3087: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Oct 17, 2024)
6ff0cd5: Bug fix in current ortho. Change test (JoeyT1994, Oct 17, 2024)
34e8e5e: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Nov 22, 2024)
d096722: Fix bug (JoeyT1994, Nov 26, 2024)
70a3f7e: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Dec 5, 2024)
9d64fe8: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Mar 19, 2025)
9d6c1bc: File removed (JoeyT1994, Mar 19, 2025)
ae17245: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Mar 23, 2025)
b648353: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Apr 1, 2025)
4e7d189: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Apr 3, 2025)
83c92b0: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Apr 7, 2025)
6f024ee: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Apr 11, 2025)
1106403: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Apr 17, 2025)
5641d32: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Apr 30, 2025)
59ef115: Merge remote-tracking branch 'upstream/main' (JoeyT1994, May 9, 2025)
72bea86: Merge remote-tracking branch 'upstream/main' (JoeyT1994, May 13, 2025)
bc14b33: Merge remote-tracking branch 'upstream/main' (JoeyT1994, May 20, 2025)
fdd47c8: Merge remote-tracking branch 'upstream/main' (JoeyT1994, May 24, 2025)
d8d3e52: Merge remote-tracking branch 'upstream/main' (JoeyT1994, Jul 21, 2025)
79b3aa2: Fix bug in delete_messages. Parameter restrictions in apply.jl (JoeyT1994, Jul 21, 2025)
795eb63: Adapt for BPC and custom device contractions (JoeyT1994, Aug 15, 2025)
2ebba6d: Fixes (JoeyT1994, Aug 15, 2025)
ed9f1dc: Improve interface (JoeyT1994, Aug 15, 2025)
9e4e3eb: Better defaults (JoeyT1994, Aug 17, 2025)
f0b83c1: Restore Project.toml (JoeyT1994, Aug 17, 2025)
756bae0: Merge remote-tracking branch 'upstream/main' into BP_GPU_and_Improvem… (JoeyT1994, Aug 17, 2025)
b2d2784: Improvements (JoeyT1994, Aug 18, 2025)
60d7bc8: Typo fix (JoeyT1994, Aug 18, 2025)
1432a4d: Use f for variable name in map (JoeyT1994, Aug 18, 2025)
6222386: set kwargs on algorithms (JoeyT1994, Aug 18, 2025)
5b4a8fa: No default maxiter (JoeyT1994, Aug 18, 2025)
813b436: Fix tests (JoeyT1994, Aug 18, 2025)
15a87a5: Improvements (JoeyT1994, Aug 19, 2025)
389f148: Simplify interface, fix tests (JoeyT1994, Aug 19, 2025)
51184bf: Add default for updated_message (JoeyT1994, Aug 19, 2025)
d962403: Update src/caches/abstractbeliefpropagationcache.jl (JoeyT1994, Aug 19, 2025)
8850117: Update src/caches/abstractbeliefpropagationcache.jl (JoeyT1994, Aug 19, 2025)
fc6572e: improvements (JoeyT1994, Aug 19, 2025)
ea88e55: Merge branch 'BP_GPU_and_Improvements' of github.com:JoeyT1994/ITenso… (JoeyT1994, Aug 19, 2025)
27ec78f: No kwargs being passed that arent needed (JoeyT1994, Aug 19, 2025)
5ec0710: update_iteration (JoeyT1994, Aug 19, 2025)
7b2c3d9: Update src/caches/abstractbeliefpropagationcache.jl (JoeyT1994, Aug 20, 2025)
d58c3f4: Update docs and examples .toml (JoeyT1994, Aug 20, 2025)
d918da1: Merge branch 'BP_GPU_and_Improvements' of github.com:JoeyT1994/ITenso… (JoeyT1994, Aug 20, 2025)
6 changes: 2 additions & 4 deletions Project.toml
@@ -1,10 +1,11 @@
name = "ITensorNetworks"
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
version = "0.13.17"
version = "0.14.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
@@ -33,15 +34,13 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"

[extensions]
ITensorNetworksAdaptExt = "Adapt"
ITensorNetworksEinExprsExt = "EinExprs"
ITensorNetworksGraphsFlowsExt = "GraphsFlows"
ITensorNetworksOMEinsumContractionOrdersExt = "OMEinsumContractionOrders"
@@ -82,7 +81,6 @@ TupleTools = "1.4"
julia = "1.10"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
14 changes: 0 additions & 14 deletions ext/ITensorNetworksAdaptExt/ITensorNetworksAdaptExt.jl

This file was deleted.

16 changes: 16 additions & 0 deletions src/abstractitensornetwork.jl
@@ -1,3 +1,4 @@
using Adapt: Adapt, adapt, adapt_structure
using DataGraphs:
DataGraphs, edge_data, underlying_graph, underlying_graph_type, vertex_data
using Dictionaries: Dictionary
@@ -86,6 +87,10 @@ function DataGraphs.underlying_graph_type(G::Type{<:AbstractITensorNetwork})
return underlying_graph_type(data_graph_type(G))
end

function ITensors.datatype(tn::AbstractITensorNetwork)
return mapreduce(v -> datatype(tn[v]), promote_type, vertices(tn))
end

# AbstractDataGraphs overloads
function DataGraphs.vertex_data(graph::AbstractITensorNetwork, args...)
return vertex_data(data_graph(graph), args...)
@@ -102,6 +107,17 @@ function NamedGraphs.ordered_vertices(tn::AbstractITensorNetwork)
return NamedGraphs.ordered_vertices(underlying_graph(tn))
end

function Adapt.adapt_structure(to, tn::AbstractITensorNetwork)
# TODO: Define and use:
#
# @preserve_graph map_vertex_data(adapt(to), tn)
#
# or just:
#
# @preserve_graph map(adapt(to), tn)
return map_vertex_data_preserve_graph(adapt(to), tn)
end

#
# Iteration
#
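
The two additions above give every AbstractITensorNetwork a storage type (ITensors.datatype, promoted over the vertex tensors) and an Adapt.adapt_structure overload that maps adapt(to) over the vertex data while preserving the graph. A minimal usage sketch follows; the graph and network constructors are illustrative assumptions rather than part of this diff, and a GPU array type would follow the same pattern.

using Adapt: adapt
using ITensors: datatype
using NamedGraphs.NamedGraphGenerators: named_grid
using ITensorNetworks: siteinds, random_tensornetwork

s = siteinds("S=1/2", named_grid((3, 3)))
tn = random_tensornetwork(s; link_space=2)

# Map every vertex tensor to single-precision dense storage; the graph is untouched.
tn32 = adapt(Vector{Float32}, tn)

# The new datatype overload promotes over all vertex tensors of the network.
datatype(tn32)  # expected: Vector{Float32} under the assumptions above
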
111 changes: 66 additions & 45 deletions src/caches/abstractbeliefpropagationcache.jl
@@ -1,3 +1,4 @@
using Adapt: Adapt, adapt, adapt_structure
using Graphs: Graphs, IsDirected
using SplitApplyCombine: group
using LinearAlgebra: diag, dot
@@ -24,24 +25,37 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
end
data_graph(bpc::AbstractBeliefPropagationCache) = data_graph(tensornetwork(bpc))

function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
sequence = contraction_sequence(contract_list; alg="optimal")
updated_messages = contract(contract_list; sequence, kwargs...)
function message_update(alg::Algorithm"contract", contract_list::Vector{ITensor})
sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg)
updated_messages = contract(contract_list; sequence)
message_norm = norm(updated_messages)
if normalize && !iszero(message_norm)
if alg.kwargs.normalize && !iszero(message_norm)
updated_messages /= message_norm
end
return ITensor[updated_messages]
end

function message_update(alg::Algorithm"adapt_update", contract_list::Vector{ITensor})
adapted_contract_list = alg.kwargs.adapt.(contract_list)
updated_messages = message_update(alg.kwargs.alg, adapted_contract_list)
dtype = mapreduce(datatype, promote_type, contract_list)
return map(adapt(dtype), updated_messages)
end

#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
lhs, rhs = contract(message_a), contract(message_b)
f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs)))
return 1 - f
end

default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e]
function default_message(datatype::Type{<:AbstractArray}, inds_e)
return [adapt(datatype, denseblocks(delta(i))) for i in inds_e]
end

function default_message(elt::Type{<:Number}, inds_e)
return default_message(Vector{elt}, inds_e)
end
default_messages(ptn::PartitionedGraph) = Dictionary()
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
@traitfn function default_bp_maxiter(g::::IsDirected)
@@ -59,15 +73,13 @@ function default_message(
)
return not_implemented()
end
default_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_bp_maxiter(alg::Algorithm, bpc::AbstractBeliefPropagationCache) = not_implemented()
function default_edge_sequence(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
return not_implemented()
end
function default_message_update_kwargs(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
return not_implemented()
end
function environment(bpc::AbstractBeliefPropagationCache, verts::Vector; kwargs...)
return not_implemented()
end
@@ -80,21 +92,8 @@ end
partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
PartitionedGraphs.partitionedges(bpc::AbstractBeliefPropagationCache) = not_implemented()

function default_edge_sequence(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_edge_sequence(Algorithm(alg), bpc)
end
function default_bp_maxiter(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_bp_maxiter(Algorithm(alg), bpc)
end
function default_message_update_kwargs(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_message_update_kwargs(Algorithm(alg), bpc)
end
default_bp_edge_sequence(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_bp_maxiter(bpc::AbstractBeliefPropagationCache) = not_implemented()

function tensornetwork(bpc::AbstractBeliefPropagationCache)
return unpartitioned_graph(partitioned_tensornetwork(bpc))
@@ -144,6 +143,30 @@ function incoming_messages(
return incoming_messages(bpc, [partition_vertex]; kwargs...)
end

#Adapt interface for changing device
function map_messages(f, bpc::AbstractBeliefPropagationCache)
bpc = copy(bpc)
for pe in keys(messages(bpc))
set_message!(bpc, pe, f.(message(bpc, pe)))
end
return bpc
end
function map_factors(f, bpc::AbstractBeliefPropagationCache)
bpc = copy(bpc)
for v in vertices(bpc)
@preserve_graph bpc[v] = f(bpc[v])
end
return bpc
end
adapt_messages(to, bpc::AbstractBeliefPropagationCache) = map_messages(adapt(to), bpc)
adapt_factors(to, bpc::AbstractBeliefPropagationCache) = map_factors(adapt(to), bpc)

function Adapt.adapt_structure(to, bpc::AbstractBeliefPropagationCache)
bpc = adapt_messages(to, bpc)
bpc = adapt_factors(to, bpc)
return bpc
end

#Forward from partitioned graph
for f in [
:(PartitionedGraphs.partitioned_graph),
@@ -240,16 +263,14 @@ Compute message tensor as product of incoming mts and local state
function updated_message(
bpc::AbstractBeliefPropagationCache,
edge::PartitionEdge;
message_update_function=default_message_update,
message_update_function_kwargs=(;),
message_update_alg=default_message_update_alg(bpc),
kwargs...,
)
vertex = src(edge)
incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
state = factors(bpc, vertex)

return message_update_function(
ITensor[incoming_ms; state]; message_update_function_kwargs...
)
return message_update(message_update_alg, ITensor[incoming_ms; state]; kwargs...)
end

function update(
@@ -306,21 +327,25 @@ More generic interface for update, with default params
function update(
alg::Algorithm,
bpc::AbstractBeliefPropagationCache;
edges=default_edge_sequence(alg, bpc),
maxiter=default_bp_maxiter(alg, bpc),
message_update_kwargs=default_message_update_kwargs(alg, bpc),
tol=nothing,
verbose=false,
message_update_alg=default_message_update_alg(bpc),
kwargs...,
)
compute_error = !isnothing(tol)
if isnothing(maxiter)
compute_error = !isnothing(alg.kwargs.tol)
if isnothing(alg.kwargs.maxiter)
error("You need to specify a number of iterations for BP!")
end
for i in 1:maxiter
for i in 1:alg.kwargs.maxiter
diff = compute_error ? Ref(0.0) : nothing
bpc = update(alg, bpc, edges; (update_diff!)=diff, message_update_kwargs...)
if compute_error && (diff.x / length(edges)) <= tol
if verbose
bpc = update(
alg,
bpc,
alg.kwargs.edge_sequence;
(update_diff!)=diff,
message_update_alg=set_kwargs(message_update_alg),
kwargs...,
)
if compute_error && (diff.x / length(edges)) <= alg.kwargs.tol
if alg.kwargs.verbose
println("BP converged to desired precision after $i iterations.")
end
break
@@ -329,12 +354,8 @@ function update(
return bpc
end

function update(
bpc::AbstractBeliefPropagationCache;
alg::String=default_message_update_alg(bpc),
kwargs...,
)
return update(Algorithm(alg), bpc; kwargs...)
function update(bpc::AbstractBeliefPropagationCache; alg=default_update_alg(bpc), kwargs...)
return update(set_kwargs(alg, bpc), bpc; kwargs...)
end

function rescale_messages(
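
With these changes the belief propagation loop is parameterized entirely by Algorithm objects: the "bp" algorithm carries maxiter, tol, verbose, and edge_sequence in its kwargs (filled in by set_kwargs), and the per-edge message update is dispatched on a separate message-update Algorithm. A sketch of how a caller might drive the new interface, assuming Algorithm is the keyword-carrying algorithm type already used throughout this diff and tn is an existing tensor network:

bpc = BeliefPropagationCache(tn)

# BP parameters ride on the "bp" Algorithm; set_kwargs fills in any that are omitted.
bpc = update(bpc; alg=Algorithm("bp"; maxiter=25, tol=1e-8, verbose=true))

# The message update rule is itself an Algorithm and can be swapped per call.
bpc = update(
  bpc;
  alg=Algorithm("bp"; maxiter=25, tol=1e-8),
  message_update_alg=Algorithm("contract"; normalize=true, sequence_alg="optimal"),
)
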
40 changes: 30 additions & 10 deletions src/caches/beliefpropagationcache.jl
@@ -51,7 +51,6 @@
function cache(alg::Algorithm"bp", tn; kwargs...)
return BeliefPropagationCache(tn; kwargs...)
end
default_cache_update_kwargs(alg::Algorithm"bp") = (; maxiter=25, tol=1e-8)

function partitioned_tensornetwork(bp_cache::BeliefPropagationCache)
return bp_cache.partitioned_tensornetwork
@@ -60,7 +59,7 @@ end
messages(bp_cache::BeliefPropagationCache) = bp_cache.messages

function default_message(bp_cache::BeliefPropagationCache, edge::PartitionEdge)
return default_message(scalartype(bp_cache), linkinds(bp_cache, edge))
return default_message(datatype(bp_cache), linkinds(bp_cache, edge))
end

function Base.copy(bp_cache::BeliefPropagationCache)
@@ -69,19 +68,40 @@ function Base.copy(bp_cache::BeliefPropagationCache)
)
end

default_message_update_alg(bp_cache::BeliefPropagationCache) = "bp"
function default_update_alg(bp_cache::BeliefPropagationCache)
Algorithm(
"bp";
verbose=false,
maxiter=default_bp_maxiter(bp_cache),
edge_sequence=default_bp_edge_sequence(bp_cache),
tol=nothing,
)
end
function default_message_update_alg(bp_cache::BeliefPropagationCache)
Algorithm("contract"; normalize=true, sequence_alg="optimal")
end
function set_kwargs(alg::Algorithm"contract")
normalize = get(alg.kwargs, :normalize, true)
sequence_alg = get(alg.kwargs, :sequence_alg, "optimal")
return Algorithm("contract"; normalize, sequence_alg)
end
function set_kwargs(alg::Algorithm"adapt_update")
return Algorithm("adapt_update"; adapt=alg.kwargs.adapt, alg=set_kwargs(alg.kwargs.alg))
end
function set_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)
verbose = get(alg.kwargs, :verbose, false)
maxiter = get(alg.kwargs, :maxiter, default_bp_maxiter(bp_cache))
edge_sequence = get(alg.kwargs, :edge_sequence, default_bp_edge_sequence(bp_cache))
tol = get(alg.kwargs, :tol, nothing)
return Algorithm("bp"; verbose, maxiter, edge_sequence, tol)
end

function default_bp_maxiter(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)
function default_bp_maxiter(bp_cache::BeliefPropagationCache)
return default_bp_maxiter(partitioned_graph(bp_cache))
end
function default_edge_sequence(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)
function default_bp_edge_sequence(bp_cache::BeliefPropagationCache)
return default_edge_sequence(partitioned_tensornetwork(bp_cache))
end
function default_message_update_kwargs(
alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache
)
return (;)
end

Base.setindex!(bpc::BeliefPropagationCache, factor::ITensor, vertex) = not_implemented()
partitions(bpc::BeliefPropagationCache) = partitionvertices(partitioned_tensornetwork(bpc))
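
The new map/adapt helpers and the "adapt_update" message algorithm make it possible to hold the cache in one storage type while performing the message contractions in another (the motivating case being device tensors whose contractions are easier on the host). A sketch under the assumption that adapting ITensors to plain Julia array types works as elsewhere in this diff and that tn is an existing tensor network:

using Adapt: adapt

bpc = BeliefPropagationCache(tn)

# Move the whole cache (factors and messages) to single-precision storage;
# the new adapt_factors/adapt_messages helpers do each half individually.
bpc = adapt(Vector{Float32}, bpc)

# "adapt_update" contracts each message on an adapted copy of the tensors
# (here double precision, but a host array type for a GPU cache works the same way)
# and then casts the result back to the datatype promoted from the originals.
message_update_alg = Algorithm(
  "adapt_update";
  adapt=x -> adapt(Vector{Float64}, x),
  alg=Algorithm("contract"; normalize=true, sequence_alg="optimal"),
)
bpc = update(bpc; alg=Algorithm("bp"; maxiter=20, tol=1e-8), message_update_alg)
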
2 changes: 1 addition & 1 deletion src/contract.jl
@@ -53,7 +53,7 @@ function logscalar(
(cache!)=nothing,
cache_construction_kwargs=default_cache_construction_kwargs(alg, tn),
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(alg),
cache_update_kwargs=(;),
)
if isnothing(cache!)
cache! = Ref(cache(alg, tn; cache_construction_kwargs...))
2 changes: 1 addition & 1 deletion src/environment.jl
@@ -25,7 +25,7 @@ function environment(
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_construction_kwargs=default_cache_construction_kwargs(alg, ptn),
cache_update_kwargs=default_cache_update_kwargs(alg),
cache_update_kwargs=(;),
)
if isnothing(cache!)
cache! = Ref(cache(alg, ptn; cache_construction_kwargs...))
2 changes: 1 addition & 1 deletion src/expect.jl
@@ -25,7 +25,7 @@ function expect(
ops;
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(alg),
cache_update_kwargs=(;),
cache_construction_kwargs=(;),
kwargs...,
)
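
Because default_cache_update_kwargs(alg) has been removed, logscalar, environment, and expect (and VidalITensorNetwork in gauging.jl further down) no longer update their caches with a built-in maxiter/tol; convergence parameters are now passed explicitly. A sketch of restoring the old (maxiter = 25, tol = 1e-8) behaviour in the new style, assuming the usual top-level expect entry point forwards these keywords to the cache update:

sz = expect(
  ψ, "Sz";
  alg="bp",
  cache_update_kwargs=(; alg=Algorithm("bp"; maxiter=25, tol=1e-8)),
)
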
4 changes: 3 additions & 1 deletion src/formnetworks/bilinearformnetwork.jl
@@ -1,5 +1,6 @@
using Adapt: adapt
using ITensors: ITensor, Op, prime, sim
using ITensors.NDTensors: denseblocks
using ITensors.NDTensors: datatype, denseblocks

default_dual_site_index_map = prime
default_dual_link_index_map = sim
@@ -76,6 +77,7 @@ function BilinearFormNetwork(
O = ITensorNetwork(operator_inds; link_space) do v
return inds -> itensor_identity_map(scalartype(ket), s[v] .=> s_mapped[v])
end
O = adapt(promote_type(datatype(bra), datatype(ket)), O)
return BilinearFormNetwork(O, bra, ket; dual_site_index_map, kwargs...)
end

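
The adapt call added above means the internally built identity operator network now matches the promoted datatype of the bra and ket, so forming a bilinear form from adapted states does not silently reintroduce Float64 dense tensors. A sketch, assuming a two-network BilinearFormNetwork(bra, ket) constructor as used in the body above and an existing network ψ:

using Adapt: adapt
using ITensorNetworks: BilinearFormNetwork

ψ32 = adapt(Vector{Float32}, ψ)
blf = BilinearFormNetwork(ψ32, ψ32)
# The identity tensors inside blf are adapted to promote_type(datatype(bra), datatype(ket)),
# which should be Vector{Float32} under these assumptions.
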
2 changes: 1 addition & 1 deletion src/gauging.jl
@@ -128,7 +128,7 @@ function VidalITensorNetwork(
ψ::ITensorNetwork;
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(Algorithm("bp")),
cache_update_kwargs=(;),
kwargs...,
)
if isnothing(cache!)