@@ -935,24 +935,29 @@ function add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
935935 return tn12
936936end
937937
938- """ Scale each tensor of the network by a scale factor on each vertex"""
939- function scale! (tn:: AbstractITensorNetwork , vertices_weights:: Dictionary )
940- for v in keys (vertices_weights)
941- setindex_preserve_graph! (tn, vertices_weights[v] * tn[v], v)
938+ """ Scale each tensor of the network via a function vertex -> Number"""
939+ function scale! (
940+ weight_function:: Function , tn:: AbstractITensorNetwork ; verts= collect (vertices (tn))
941+ )
942+ for v in verts
943+ setindex_preserve_graph! (tn, weight_function (v) * tn[v], v)
942944 end
943945 return tn
944946end
945947
946- """ Scale each tensor of the network via a function (vertex, ITensor) -> Number"""
947- function scale! (tn:: AbstractITensorNetwork , weight_function:: Function )
948- vs = collect (vertices (tn))
949- vertices_weights = Dictionary (vs, [weight_function (v, tn[v]) for v in vs])
950- return scale! (tn, vertices_weights)
948+ """ Scale each tensor of the network by a scale factor for each vertex in the keys of the dictionary"""
949+ function scale! (tn:: AbstractITensorNetwork , vertices_weights:: Dictionary )
950+ return scale! (v -> vertices_weights[v], tn; verts= keys (vertices_weights))
951+ end
952+
953+ function scale (weight_function:: Function , tn; kwargs... )
954+ tn = copy (tn)
955+ return scale! (weight_function, tn; kwargs... )
951956end
952957
953- function scale (tn, args ... )
958+ function scale (tn, vertices_weights :: Dictionary ; kwargs ... )
954959 tn = copy (tn)
955- return scale! (tn, args ... )
960+ return scale! (tn, vertices_weights; kwargs ... )
956961end
957962
958963Base.:+ (tn1:: AbstractITensorNetwork , tn2:: AbstractITensorNetwork ) = add (tn1, tn2)
0 commit comments