11using LinearAlgebra
22
3- function rescale (tn:: AbstractITensorNetwork , c:: Number , vs= collect (vertices (tn)))
3+ function rescale (tn:: AbstractITensorNetwork ; alg= " exact" , kwargs... )
4+ return rescale (Algorithm (alg), tn; kwargs... )
5+ end
6+
7+ function rescale (
8+ alg:: Algorithm"exact" , tn:: AbstractITensorNetwork , vs= collect (vertices (tn)); kwargs...
9+ )
10+ logn = logscalar (alg, tn; kwargs... )
11+ c = 1.0 / (exp (logn / length (vs)))
412 tn = copy (tn)
513 for v in vs
614 tn[v] *= c
715 end
816 return tn
917end
1018
11- function LinearAlgebra. normalize (tn:: AbstractITensorNetwork ; alg= " exact" , kwargs... )
12- return normalize (Algorithm (alg), tn; kwargs... )
13- end
14-
15- function LinearAlgebra. normalize (alg:: Algorithm"exact" , tn:: AbstractITensorNetwork )
16- norm_tn = QuadraticFormNetwork (tn)
17- c = exp (logscalar (alg, norm_tn) / (2 * length (vertices (tn))))
18- return rescale (tn, 1 / c)
19- end
20-
21- function LinearAlgebra. normalize (
19+ function rescale (
2220 alg:: Algorithm"bp" ,
23- tn:: AbstractITensorNetwork ;
21+ tn:: AbstractITensorNetwork ,
22+ vs= collect (vertices (tn));
2423 (cache!)= nothing ,
2524 update_cache= isnothing (cache!),
2625 cache_update_kwargs= default_cache_update_kwargs (cache!),
2726)
2827 if isnothing (cache!)
29- cache! = Ref (BeliefPropagationCache (QuadraticFormNetwork (tn )))
28+ cache! = Ref (BeliefPropagationCache (tn, group (v -> v, vertices (tn) )))
3029 end
3130
3231 if update_cache
@@ -35,19 +34,57 @@ function LinearAlgebra.normalize(
3534
3635 tn = copy (tn)
3736 cache![] = normalize_messages (cache![])
38- norm_tn = tensornetwork (cache![])
39-
4037 vertices_states = Dictionary ()
41- for v in vertices (tn)
42- v_ket, v_bra = ket_vertex (norm_tn, v), bra_vertex (norm_tn, v)
43- pv = only (partitionvertices (cache![], [v_ket]))
38+ for pv in partitionvertices (cache![])
39+ pv_vs = filter (v -> v ∈ vs, vertices (cache![], pv))
40+
41+ isempty (pv_vs) && continue
42+
4443 vn = region_scalar (cache![], pv)
45- norm_tn = rescale (norm_tn, 1 / sqrt (vn), [v_ket, v_bra])
46- set! (vertices_states, v_ket, norm_tn[v_ket])
47- set! (vertices_states, v_bra, norm_tn[v_bra])
44+ if isreal (vn) && vn < 0
45+ tn[first (pv_vs)] *= - 1
46+ vn = abs (vn)
47+ end
48+
49+ vn = vn^ (1 / length (pv_vs))
50+ for v in pv_vs
51+ tn[v] /= vn
52+ set! (vertices_states, v, tn[v])
53+ end
4854 end
4955
5056 cache![] = update_factors (cache![], vertices_states)
57+ return tn
58+ end
59+
60+ function LinearAlgebra. normalize (tn:: AbstractITensorNetwork ; alg= " exact" , kwargs... )
61+ return normalize (Algorithm (alg), tn; kwargs... )
62+ end
63+
64+ function LinearAlgebra. normalize (
65+ alg:: Algorithm"exact" , tn:: AbstractITensorNetwork ; kwargs...
66+ )
67+ norm_tn = QuadraticFormNetwork (tn)
68+ vs = filter (v -> v ∉ operator_vertices (norm_tn), collect (vertices (norm_tn)))
69+ return ket_network (rescale (alg, norm_tn, vs; kwargs... ))
70+ end
71+
72+ function LinearAlgebra. normalize (
73+ alg:: Algorithm ,
74+ tn:: AbstractITensorNetwork ;
75+ (cache!)= nothing ,
76+ cache_construction_function= tn ->
77+ cache (alg, tn; default_cache_construction_kwargs (alg, tn)... ),
78+ update_cache= isnothing (cache!),
79+ cache_update_kwargs= default_cache_update_kwargs (cache!),
80+ )
81+ norm_tn = QuadraticFormNetwork (tn)
82+ if isnothing (cache!)
83+ cache! = Ref (cache_construction_function (norm_tn))
84+ end
85+
86+ vs = filter (v -> v ∉ operator_vertices (norm_tn), collect (vertices (norm_tn)))
87+ norm_tn = rescale (alg, norm_tn, vs; cache!, update_cache, cache_update_kwargs)
5188
5289 return ket_network (norm_tn)
5390end
0 commit comments