Skip to content

Commit 5bc9888

Browse files
JoeyT1994mtfishman
andauthored
Allow BP to work on gpus and various kwarg improvements (#253)
Co-authored-by: Matt Fishman <[email protected]>
1 parent 27a2072 commit 5bc9888

19 files changed

+169
-121
lines changed

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.13.17"
4+
version = "0.14.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
8+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
910
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1011
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
@@ -33,15 +34,13 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3334
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
3435

3536
[weakdeps]
36-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3737
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
3838
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
3939
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
4040
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
4141
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
4242

4343
[extensions]
44-
ITensorNetworksAdaptExt = "Adapt"
4544
ITensorNetworksEinExprsExt = "EinExprs"
4645
ITensorNetworksGraphsFlowsExt = "GraphsFlows"
4746
ITensorNetworksOMEinsumContractionOrdersExt = "OMEinsumContractionOrders"
@@ -82,7 +81,6 @@ TupleTools = "1.4"
8281
julia = "1.10"
8382

8483
[extras]
85-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8684
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
8785
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
8886
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
55

66
[compat]
77
Documenter = "1.10.0"
8-
ITensorNetworks = "0.13.0"
8+
ITensorNetworks = "0.14.0"
99
Literate = "2.20.1"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33

44
[compat]
5-
ITensorNetworks = "0.13.2"
5+
ITensorNetworks = "0.14.0"

ext/ITensorNetworksAdaptExt/ITensorNetworksAdaptExt.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

src/abstractitensornetwork.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Adapt: Adapt, adapt, adapt_structure
12
using DataGraphs:
23
DataGraphs, edge_data, underlying_graph, underlying_graph_type, vertex_data
34
using Dictionaries: Dictionary
@@ -86,6 +87,10 @@ function DataGraphs.underlying_graph_type(G::Type{<:AbstractITensorNetwork})
8687
return underlying_graph_type(data_graph_type(G))
8788
end
8889

90+
function ITensors.datatype(tn::AbstractITensorNetwork)
91+
return mapreduce(v -> datatype(tn[v]), promote_type, vertices(tn))
92+
end
93+
8994
# AbstractDataGraphs overloads
9095
function DataGraphs.vertex_data(graph::AbstractITensorNetwork, args...)
9196
return vertex_data(data_graph(graph), args...)
@@ -102,6 +107,17 @@ function NamedGraphs.ordered_vertices(tn::AbstractITensorNetwork)
102107
return NamedGraphs.ordered_vertices(underlying_graph(tn))
103108
end
104109

110+
function Adapt.adapt_structure(to, tn::AbstractITensorNetwork)
111+
# TODO: Define and use:
112+
#
113+
# @preserve_graph map_vertex_data(adapt(to), tn)
114+
#
115+
# or just:
116+
#
117+
# @preserve_graph map(adapt(to), tn)
118+
return map_vertex_data_preserve_graph(adapt(to), tn)
119+
end
120+
105121
#
106122
# Iteration
107123
#

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 89 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Adapt: Adapt, adapt, adapt_structure
12
using Graphs: Graphs, IsDirected
23
using SplitApplyCombine: group
34
using LinearAlgebra: diag, dot
@@ -24,24 +25,20 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
2425
end
2526
data_graph(bpc::AbstractBeliefPropagationCache) = data_graph(tensornetwork(bpc))
2627

27-
function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
28-
sequence = contraction_sequence(contract_list; alg="optimal")
29-
updated_messages = contract(contract_list; sequence, kwargs...)
30-
message_norm = norm(updated_messages)
31-
if normalize && !iszero(message_norm)
32-
updated_messages /= message_norm
33-
end
34-
return ITensor[updated_messages]
35-
end
36-
3728
#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages
3829
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
3930
lhs, rhs = contract(message_a), contract(message_b)
4031
f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs)))
4132
return 1 - f
4233
end
4334

44-
default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e]
35+
function default_message(datatype::Type{<:AbstractArray}, inds_e)
36+
return [adapt(datatype, denseblocks(delta(i))) for i in inds_e]
37+
end
38+
39+
function default_message(elt::Type{<:Number}, inds_e)
40+
return default_message(Vector{elt}, inds_e)
41+
end
4542
default_messages(ptn::PartitionedGraph) = Dictionary()
4643
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
4744
@traitfn function default_bp_maxiter(g::::IsDirected)
@@ -59,15 +56,13 @@ function default_message(
5956
)
6057
return not_implemented()
6158
end
59+
default_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
6260
default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
6361
Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented()
6462
default_bp_maxiter(alg::Algorithm, bpc::AbstractBeliefPropagationCache) = not_implemented()
6563
function default_edge_sequence(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
6664
return not_implemented()
6765
end
68-
function default_message_update_kwargs(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
69-
return not_implemented()
70-
end
7166
function environment(bpc::AbstractBeliefPropagationCache, verts::Vector; kwargs...)
7267
return not_implemented()
7368
end
@@ -80,21 +75,8 @@ end
8075
partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
8176
PartitionedGraphs.partitionedges(bpc::AbstractBeliefPropagationCache) = not_implemented()
8277

83-
function default_edge_sequence(
84-
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
85-
)
86-
return default_edge_sequence(Algorithm(alg), bpc)
87-
end
88-
function default_bp_maxiter(
89-
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
90-
)
91-
return default_bp_maxiter(Algorithm(alg), bpc)
92-
end
93-
function default_message_update_kwargs(
94-
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
95-
)
96-
return default_message_update_kwargs(Algorithm(alg), bpc)
97-
end
78+
default_bp_edge_sequence(bpc::AbstractBeliefPropagationCache) = not_implemented()
79+
default_bp_maxiter(bpc::AbstractBeliefPropagationCache) = not_implemented()
9880

9981
function tensornetwork(bpc::AbstractBeliefPropagationCache)
10082
return unpartitioned_graph(partitioned_tensornetwork(bpc))
@@ -144,6 +126,36 @@ function incoming_messages(
144126
return incoming_messages(bpc, [partition_vertex]; kwargs...)
145127
end
146128

129+
#Adapt interface for changing device
130+
function map_messages(
131+
f, bpc::AbstractBeliefPropagationCache, pes=collect(keys(messages(bpc)))
132+
)
133+
bpc = copy(bpc)
134+
for pe in pes
135+
set_message!(bpc, pe, f.(message(bpc, pe)))
136+
end
137+
return bpc
138+
end
139+
function map_factors(f, bpc::AbstractBeliefPropagationCache, vs=vertices(bpc))
140+
bpc = copy(bpc)
141+
for v in vs
142+
@preserve_graph bpc[v] = f(bpc[v])
143+
end
144+
return bpc
145+
end
146+
function adapt_messages(to, bpc::AbstractBeliefPropagationCache, args...)
147+
return map_messages(adapt(to), bpc, args...)
148+
end
149+
function adapt_factors(to, bpc::AbstractBeliefPropagationCache, args...)
150+
return map_factors(adapt(to), bpc, args...)
151+
end
152+
153+
function Adapt.adapt_structure(to, bpc::AbstractBeliefPropagationCache)
154+
bpc = adapt_messages(to, bpc)
155+
bpc = adapt_factors(to, bpc)
156+
return bpc
157+
end
158+
147159
#Forward from partitioned graph
148160
for f in [
149161
:(PartitionedGraphs.partitioned_graph),
@@ -234,44 +246,63 @@ function delete_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
234246
return delete_messages(bpc, [pe])
235247
end
236248

237-
"""
238-
Compute message tensor as product of incoming mts and local state
239-
"""
240249
function updated_message(
241-
bpc::AbstractBeliefPropagationCache,
242-
edge::PartitionEdge;
243-
message_update_function=default_message_update,
244-
message_update_function_kwargs=(;),
250+
alg::Algorithm"contract", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
245251
)
246252
vertex = src(edge)
247253
incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
248254
state = factors(bpc, vertex)
255+
contract_list = ITensor[incoming_ms; state]
256+
sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg)
257+
updated_messages = contract(contract_list; sequence)
258+
message_norm = norm(updated_messages)
259+
if alg.kwargs.normalize && !iszero(message_norm)
260+
updated_messages /= message_norm
261+
end
262+
return ITensor[updated_messages]
263+
end
249264

250-
return message_update_function(
251-
ITensor[incoming_ms; state]; message_update_function_kwargs...
265+
function updated_message(
266+
alg::Algorithm"adapt_update", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
267+
)
268+
incoming_pes = setdiff(
269+
boundary_partitionedges(bpc, [src(edge)]; dir=:in), [reverse(edge)]
252270
)
271+
adapted_bpc = adapt_messages(alg.kwargs.adapt, bpc, incoming_pes)
272+
adapted_bpc = adapt_factors(alg.kwargs.adapt, bpc, vertices(bpc, src(edge)))
273+
updated_messages = updated_message(alg.kwargs.alg, adapted_bpc, edge)
274+
dtype = mapreduce(datatype, promote_type, message(bpc, edge))
275+
return map(adapt(dtype), updated_messages)
276+
end
277+
278+
function updated_message(
279+
bpc::AbstractBeliefPropagationCache,
280+
edge::PartitionEdge;
281+
alg=default_message_update_alg(bpc),
282+
kwargs...,
283+
)
284+
return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bpc, edge)
253285
end
254286

255-
function update(
256-
alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...
287+
function update_message(
288+
message_update_alg::Algorithm, bpc::AbstractBeliefPropagationCache, edge::PartitionEdge
257289
)
258-
return set_message(bpc, edge, updated_message(bpc, edge; kwargs...))
290+
return set_message(bpc, edge, updated_message(message_update_alg, bpc, edge))
259291
end
260292

261293
"""
262294
Do a sequential update of the message tensors on `edges`
263295
"""
264-
function update(
265-
alg::Algorithm,
296+
function update_iteration(
297+
alg::Algorithm"bp",
266298
bpc::AbstractBeliefPropagationCache,
267299
edges::Vector;
268300
(update_diff!)=nothing,
269-
kwargs...,
270301
)
271302
bpc = copy(bpc)
272303
for e in edges
273304
prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing
274-
bpc = update(alg, bpc, e; kwargs...)
305+
bpc = update_message(alg.kwargs.message_update_alg, bpc, e)
275306
if !isnothing(update_diff!)
276307
update_diff![] += message_diff(message(bpc, e), prev_message)
277308
end
@@ -284,15 +315,15 @@ Do parallel updates between groups of edges of all message tensors
284315
Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
285316
mts relevant to that group.
286317
"""
287-
function update(
288-
alg::Algorithm,
318+
function update_iteration(
319+
alg::Algorithm"bp",
289320
bpc::AbstractBeliefPropagationCache,
290321
edge_groups::Vector{<:Vector{<:PartitionEdge}};
291-
kwargs...,
322+
(update_diff!)=nothing,
292323
)
293324
new_mts = empty(messages(bpc))
294325
for edges in edge_groups
295-
bpc_t = update(alg, bpc, edges; kwargs...)
326+
bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!))
296327
for e in edges
297328
set!(new_mts, e, message(bpc_t, e))
298329
end
@@ -303,24 +334,16 @@ end
303334
"""
304335
More generic interface for update, with default params
305336
"""
306-
function update(
307-
alg::Algorithm,
308-
bpc::AbstractBeliefPropagationCache;
309-
edges=default_edge_sequence(alg, bpc),
310-
maxiter=default_bp_maxiter(alg, bpc),
311-
message_update_kwargs=default_message_update_kwargs(alg, bpc),
312-
tol=nothing,
313-
verbose=false,
314-
)
315-
compute_error = !isnothing(tol)
316-
if isnothing(maxiter)
337+
function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache)
338+
compute_error = !isnothing(alg.kwargs.tol)
339+
if isnothing(alg.kwargs.maxiter)
317340
error("You need to specify a number of iterations for BP!")
318341
end
319-
for i in 1:maxiter
342+
for i in 1:alg.kwargs.maxiter
320343
diff = compute_error ? Ref(0.0) : nothing
321-
bpc = update(alg, bpc, edges; (update_diff!)=diff, message_update_kwargs...)
322-
if compute_error && (diff.x / length(edges)) <= tol
323-
if verbose
344+
bpc = update_iteration(alg, bpc, alg.kwargs.edge_sequence; (update_diff!)=diff)
345+
if compute_error && (diff.x / length(alg.kwargs.edge_sequence)) <= alg.kwargs.tol
346+
if alg.kwargs.verbose
324347
println("BP converged to desired precision after $i iterations.")
325348
end
326349
break
@@ -329,12 +352,8 @@ function update(
329352
return bpc
330353
end
331354

332-
function update(
333-
bpc::AbstractBeliefPropagationCache;
334-
alg::String=default_message_update_alg(bpc),
335-
kwargs...,
336-
)
337-
return update(Algorithm(alg), bpc; kwargs...)
355+
function update(bpc::AbstractBeliefPropagationCache; alg=default_update_alg(bpc), kwargs...)
356+
return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc)
338357
end
339358

340359
function rescale_messages(

0 commit comments

Comments
 (0)