Skip to content

Commit b2d2784

Browse files
committed
Improvements
1 parent 756bae0 commit b2d2784

File tree

3 files changed

+22
-29
lines changed

3 files changed

+22
-29
lines changed

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,21 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
2525
end
2626
data_graph(bpc::AbstractBeliefPropagationCache) = data_graph(tensornetwork(bpc))
2727

28-
function message_update_function(
29-
alg::Algorithm"contract",
30-
contract_list::Vector{ITensor};
31-
normalize=alg.kwargs.normalize,
32-
sequence_alg=alg.kwargs.sequence_alg,
33-
)
34-
sequence = contraction_sequence(contract_list; alg=sequence_alg)
28+
function message_update(alg::Algorithm"contract", contract_list::Vector{ITensor};)
29+
sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg)
3530
updated_messages = contract(contract_list; sequence)
3631
message_norm = norm(updated_messages)
37-
if normalize && !iszero(message_norm)
32+
if alg.kwargs.normalize && !iszero(message_norm)
3833
updated_messages /= message_norm
3934
end
4035
return ITensor[updated_messages]
4136
end
4237

43-
function message_update_function(
44-
alg::Algorithm"contract_custom_device",
45-
contract_list::Vector{ITensor};
46-
normalize=alg.kwargs.normalize,
47-
sequence_alg=alg.kwargs.sequence_alg,
48-
custom_device_adapt=alg.kwargs.adapt,
49-
)
50-
adapted_contract_list = custom_device_adapt.(contract_list)
51-
updated_messages = message_update_function(
52-
Algorithm("contract"), adapted_contract_list; normalize, sequence_alg
53-
)
38+
function message_update(alg::Algorithm"adapt_update", contract_list::Vector{ITensor};)
39+
adapted_contract_list = alg.kwargs.adapt.(contract_list)
40+
updated_messages = message_update(alg.kwargs.alg, adapted_contract_list)
5441
dtype = datatype(first(contract_list))
55-
return ITensor[adapt(dtype, updated_message) for updated_message in updated_messages]
42+
return map(adapt(dtype), updated_messages)
5643
end
5744

5845
#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages
@@ -62,8 +49,12 @@ function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
6249
return 1 - f
6350
end
6451

65-
function default_message(datatype, elt, inds_e)
66-
ITensor[adapt(datatype, denseblocks(delta(elt, i))) for i in inds_e]
52+
function default_message(datatype::Type{<:AbstractArray}, inds_e)
53+
return [adapt(datatype, denseblocks(delta(i))) for i in inds_e]
54+
end
55+
56+
function default_message(elt::Type{<:Number}, inds_e)
57+
return default_message(Vector{elt}, inds_e)
6758
end
6859
default_messages(ptn::PartitionedGraph) = Dictionary()
6960
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : 30
@@ -153,20 +144,22 @@ function incoming_messages(
153144
end
154145

155146
#Adapt interface for changing device
156-
function adapt_messages(to, bpc::AbstractBeliefPropagationCache)
147+
function map_messages(map, bpc::AbstractBeliefPropagationCache)
157148
bpc = copy(bpc)
158149
for pe in keys(messages(bpc))
159-
set_message!(bpc, pe, adapt(to).(message(bpc, pe)))
150+
set_message!(bpc, pe, map.(message(bpc, pe)))
160151
end
161152
return bpc
162153
end
163-
function adapt_factors(to, bpc::AbstractBeliefPropagationCache)
154+
function map_factors(to, bpc::AbstractBeliefPropagationCache)
164155
bpc = copy(bpc)
165156
for v in vertices(bpc)
166-
@preserve_graph bpc[v] = adapt(to).(bpc[v])
157+
@preserve_graph bpc[v] = map(bpc[v])
167158
end
168159
return bpc
169160
end
161+
adapt_messages(to, bpc::AbstractBeliefPropagationCache) = map_messages(adapt(to), bpc)
162+
adapt_factors(to, bpc::AbstractBeliefPropagationCache) = map_factors(adapt(to), bpc)
170163

171164
function Adapt.adapt_structure(to, bpc::AbstractBeliefPropagationCache)
172165
bpc = adapt_messages(to, bpc)
@@ -277,7 +270,7 @@ function updated_message(
277270
incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
278271
state = factors(bpc, vertex)
279272

280-
return message_update_function(message_update_alg, ITensor[incoming_ms; state]; kwargs...)
273+
return message_update(message_update_alg, ITensor[incoming_ms; state]; kwargs...)
281274
end
282275

283276
function update(

src/caches/beliefpropagationcache.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
messages(bp_cache::BeliefPropagationCache) = bp_cache.messages
6060

6161
function default_message(bp_cache::BeliefPropagationCache, edge::PartitionEdge)
62-
return default_message(datatype(bp_cache), scalartype(bp_cache), linkinds(bp_cache, edge))
62+
return default_message(datatype(bp_cache), linkinds(bp_cache, edge))
6363
end
6464

6565
function Base.copy(bp_cache::BeliefPropagationCache)

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Glob = "1.3.1"
4444
Graphs = "1.12.0"
4545
GraphsFlows = "0.1.1"
4646
ITensorMPS = "0.3.6"
47-
ITensorNetworks = "0.13.0"
47+
ITensorNetworks = "0.14.0"
4848
ITensors = "0.7, 0.8, 0.9"
4949
KrylovKit = "0.8, 0.9, 0.10"
5050
LinearAlgebra = "1.10.0"

0 commit comments

Comments
 (0)