@@ -5,25 +5,28 @@ function rescale(tn::AbstractITensorNetwork; alg="exact", kwargs...)
55end
66
77function rescale (
8- alg:: Algorithm"exact" , tn:: AbstractITensorNetwork , vs= collect (vertices (tn)); kwargs...
8+ alg:: Algorithm"exact" ,
9+ tn:: AbstractITensorNetwork ;
10+ vs_to_rescale= collect (vertices (tn)),
11+ kwargs... ,
912)
1013 logn = logscalar (alg, tn; kwargs... )
11- c = 1.0 / (exp (logn / length (vs )))
14+ c = 1.0 / (exp (logn / length (vs_to_rescale )))
1215 tn = copy (tn)
13- for v in vs
16+ for v in vs_to_rescale
1417 tn[v] *= c
1518 end
1619 return tn
1720end
1821
1922function rescale (
2023 alg:: Algorithm ,
21- tn:: AbstractITensorNetwork ,
22- vs = collect (vertices (tn));
24+ tn:: AbstractITensorNetwork ;
25+ vs_to_rescale = collect (vertices (tn)),
2326 (cache!)= nothing ,
2427 cache_construction_kwargs= default_cache_construction_kwargs (alg, tn),
2528 update_cache= isnothing (cache!),
26- cache_update_kwargs= default_cache_update_kwargs (cache! ),
29+ cache_update_kwargs= default_cache_update_kwargs (alg ),
2730)
2831 if isnothing (cache!)
2932 cache! = Ref (cache (alg, tn; cache_construction_kwargs... ))
@@ -33,29 +36,9 @@ function rescale(
3336 cache![] = update (cache![]; cache_update_kwargs... )
3437 end
3538
36- tn = copy (tn)
37- cache![] = normalize_messages (cache![])
38- vertices_states = Dictionary ()
39- for pv in partitionvertices (cache![])
40- pv_vs = filter (v -> v ∈ vs, vertices (cache![], pv))
41-
42- isempty (pv_vs) && continue
39+ cache![] = rescale (cache![]; vs_to_rescale)
4340
44- vn = region_scalar (cache![], pv)
45- if isreal (vn) && vn < 0
46- tn[first (pv_vs)] *= - 1
47- vn = abs (vn)
48- end
49-
50- vn = vn^ (1 / length (pv_vs))
51- for v in pv_vs
52- tn[v] /= vn
53- set! (vertices_states, v, tn[v])
54- end
55- end
56-
57- cache![] = update_factors (cache![], vertices_states)
58- return tn
41+ return tensornetwork (cache![])
5942end
6043
6144function LinearAlgebra. normalize (tn:: AbstractITensorNetwork ; alg= " exact" , kwargs... )
6548function LinearAlgebra. normalize (
6649 alg:: Algorithm"exact" , tn:: AbstractITensorNetwork ; kwargs...
6750)
68- norm_tn = QuadraticFormNetwork (tn)
69- vs = filter (v -> v ∉ operator_vertices (norm_tn), collect (vertices (norm_tn)))
70- return ket_network (rescale (alg, norm_tn, vs; kwargs... ))
51+ norm_tn = inner_network (tn, tn)
52+ return ket_network (rescale (alg, norm_tn; kwargs... ))
7153end
7254
7355function LinearAlgebra. normalize (
@@ -77,15 +59,19 @@ function LinearAlgebra.normalize(
7759 cache_construction_function= tn ->
7860 cache (alg, tn; default_cache_construction_kwargs (alg, tn)... ),
7961 update_cache= isnothing (cache!),
80- cache_update_kwargs= default_cache_update_kwargs (cache!),
62+ cache_update_kwargs= default_cache_update_kwargs (alg),
63+ cache_construction_kwargs= (;),
8164)
82- norm_tn = QuadraticFormNetwork ( tn)
65+ norm_tn = inner_network (tn, tn)
8366 if isnothing (cache!)
84- cache! = Ref (cache_construction_function ( norm_tn))
67+ cache! = Ref (cache (alg, norm_tn; cache_construction_kwargs ... ))
8568 end
8669
87- vs = filter (v -> v ∉ operator_vertices (norm_tn), collect (vertices (norm_tn)))
88- norm_tn = rescale (alg, norm_tn, vs; cache!, update_cache, cache_update_kwargs)
70+ vs = collect (vertices (tn))
71+ vs_to_rescale = vcat (
72+ [ket_vertex (norm_tn, v) for v in vs], [bra_vertex (norm_tn, v) for v in vs]
73+ )
74+ norm_tn = rescale (alg, norm_tn; vs_to_rescale, cache!, update_cache, cache_update_kwargs)
8975
9076 return ket_network (norm_tn)
9177end
0 commit comments