@@ -391,19 +391,19 @@ end
391391
392392LinearAlgebra. adjoint (tn:: Union{IndsNetwork,AbstractITensorNetwork} ) = prime (tn)
393393
394- function map_vertex_data (f, tn:: AbstractITensorNetwork )
394+ # TODO : Define preserve graph version in DataGraphs.jl and @preserve_graph map_vertex_data(f, tn)`
395+ function map_vertex_data_preserve_graph (f, tn:: AbstractITensorNetwork )
395396 tn = copy (tn)
396397 for v in vertices (tn)
397- tn[v] = f (tn[v])
398+ @preserve_graph tn[v] = f (tn[v])
398399 end
399400 return tn
400401end
401402
402- # TODO : Define `@preserve_graph map_vertex_data(f, tn)`
403- function map_vertex_data_preserve_graph (f, tn:: AbstractITensorNetwork )
404- tn = copy (tn)
405- for v in vertices (tn)
406- @preserve_graph tn[v] = f (tn[v])
403+ # TODO : Define this and an out-of-place version in DataGraphs.jl
404+ function map_vertices_preserve_graph! (f, tn:: AbstractITensorNetwork ; vertices= vertices (tn))
405+ for v in vertices
406+ @preserve_graph tn[v] = f (v)
407407 end
408408 return tn
409409end
@@ -941,10 +941,7 @@ function scale!(
941941 tn:: AbstractITensorNetwork ;
942942 vertices= collect (Graphs. vertices (tn)),
943943)
944- for v in vertices
945- setindex_preserve_graph! (tn, weight_function (v) * tn[v], v)
946- end
947- return tn
944+ return map_vertices_preserve_graph! (v -> weight_function (v) * tn[v], tn; vertices)
948945end
949946
950947""" Scale each tensor of the network by a scale factor for each vertex in the keys of the dictionary"""
0 commit comments