Skip to content

Commit 2141045

Browse files
committed
Improvements
1 parent 4394909 commit 2141045

File tree

3 files changed

+21
-16
lines changed

3 files changed

+21
-16
lines changed

src/abstractitensornetwork.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -935,24 +935,29 @@ function add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
935935
return tn12
936936
end
937937

938-
""" Scale each tensor of the network by a scale factor on each vertex"""
939-
function scale!(tn::AbstractITensorNetwork, vertices_weights::Dictionary)
940-
for v in keys(vertices_weights)
941-
setindex_preserve_graph!(tn, vertices_weights[v] * tn[v], v)
938+
""" Scale each tensor of the network via a function vertex -> Number"""
939+
function scale!(
940+
weight_function::Function, tn::AbstractITensorNetwork; verts=collect(vertices(tn))
941+
)
942+
for v in verts
943+
setindex_preserve_graph!(tn, weight_function(v) * tn[v], v)
942944
end
943945
return tn
944946
end
945947

946-
""" Scale each tensor of the network via a function (vertex, ITensor) -> Number"""
947-
function scale!(tn::AbstractITensorNetwork, weight_function::Function)
948-
vs = collect(vertices(tn))
949-
vertices_weights = Dictionary(vs, [weight_function(v, tn[v]) for v in vs])
950-
return scale!(tn, vertices_weights)
948+
""" Scale each tensor of the network by a scale factor for each vertex in the keys of the dictionary"""
949+
function scale!(tn::AbstractITensorNetwork, vertices_weights::Dictionary)
950+
return scale!(v -> vertices_weights[v], tn; verts=keys(vertices_weights))
951+
end
952+
953+
function scale(weight_function::Function, tn; kwargs...)
954+
tn = copy(tn)
955+
return scale!(weight_function, tn; kwargs...)
951956
end
952957

953-
function scale(tn, args...)
958+
function scale(tn, vertices_weights::Dictionary; kwargs...)
954959
tn = copy(tn)
955-
return scale!(tn, args...)
960+
return scale!(tn, vertices_weights; kwargs...)
956961
end
957962

958963
Base.:+(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork) = add(tn1, tn2)

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,16 +305,16 @@ end
305305
function rescale_partitions(
306306
bpc::AbstractBeliefPropagationCache,
307307
partitions::Vector;
308-
verts_to_rescale::Vector=vertices(bpc, partitions),
308+
verts::Vector=vertices(bpc, partitions),
309309
)
310310
bpc = copy(bpc)
311311
tn = tensornetwork(bpc)
312-
norms = map(v -> inv(norm(tn[v])), verts_to_rescale)
313-
scale!(bpc, Dictionary(verts_to_rescale, norms))
312+
norms = map(v -> inv(norm(tn[v])), verts)
313+
scale!(bpc, Dictionary(verts, norms))
314314

315315
vertices_weights = Dictionary()
316316
for pv in partitions
317-
pv_vs = filter(v -> v verts_to_rescale, vertices(bpc, pv))
317+
pv_vs = filter(v -> v verts, vertices(bpc, pv))
318318
isempty(pv_vs) && continue
319319

320320
vn = region_scalar(bpc, pv)

src/normalize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearAlgebra
1+
using LinearAlgebra: normalize
22

33
function rescale(tn::AbstractITensorNetwork; alg="exact", kwargs...)
44
return rescale(Algorithm(alg), tn; kwargs...)

0 commit comments

Comments
 (0)