Skip to content

Commit 1bb92fd

Browse files
committed
Define in terms of a map_vertices_function
1 parent e7e6134 commit 1bb92fd

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

src/abstractitensornetwork.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -391,19 +391,19 @@ end
391391

392392
LinearAlgebra.adjoint(tn::Union{IndsNetwork,AbstractITensorNetwork}) = prime(tn)
393393

394-
function map_vertex_data(f, tn::AbstractITensorNetwork)
394+
# TODO: Define preserve graph version in DataGraphs.jl and @preserve_graph map_vertex_data(f, tn)`
395+
function map_vertex_data_preserve_graph(f, tn::AbstractITensorNetwork)
395396
tn = copy(tn)
396397
for v in vertices(tn)
397-
tn[v] = f(tn[v])
398+
@preserve_graph tn[v] = f(tn[v])
398399
end
399400
return tn
400401
end
401402

402-
# TODO: Define `@preserve_graph map_vertex_data(f, tn)`
403-
function map_vertex_data_preserve_graph(f, tn::AbstractITensorNetwork)
404-
tn = copy(tn)
405-
for v in vertices(tn)
406-
@preserve_graph tn[v] = f(tn[v])
403+
# TODO: Define this and an out-of-place version in DataGraphs.jl
404+
function map_vertices_preserve_graph!(f, tn::AbstractITensorNetwork; vertices=vertices(tn))
405+
for v in vertices
406+
@preserve_graph tn[v] = f(v)
407407
end
408408
return tn
409409
end
@@ -941,10 +941,7 @@ function scale!(
941941
tn::AbstractITensorNetwork;
942942
vertices=collect(Graphs.vertices(tn)),
943943
)
944-
for v in vertices
945-
setindex_preserve_graph!(tn, weight_function(v) * tn[v], v)
946-
end
947-
return tn
944+
return map_vertices_preserve_graph!(v -> weight_function(v) * tn[v], tn; vertices)
948945
end
949946

950947
""" Scale each tensor of the network by a scale factor for each vertex in the keys of the dictionary"""

0 commit comments

Comments
 (0)