Skip to content

Commit d460884

Browse files
committed
Improve orthogonalize. Simplify code
1 parent d50387b commit d460884

File tree

5 files changed

+71
-92
lines changed

5 files changed

+71
-92
lines changed

src/abstractitensornetwork.jl

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -585,36 +585,61 @@ end
585585

586586
# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
587587
function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
588+
return _orthogonalize_edges(tn, [edge]; kwargs...)
589+
end
590+
591+
# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
592+
function _orthogonalize_edges(
593+
tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs...
594+
)
588595
# tn = factorize(tn, edge; kwargs...)
589596
# # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))`
590597
# new_vertex = only(neighbors(tn, src(edge)) ∩ neighbors(tn, dst(edge)))
591598
# return contract(tn, new_vertex => dst(edge))
592599
tn = copy(tn)
593-
left_inds = uniqueinds(tn, edge)
594-
ltags = tags(tn, edge)
595-
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
596-
tn[src(edge)] = X
597-
tn[dst(edge)] *= Y
600+
for edge in edges
601+
left_inds = uniqueinds(tn, edge)
602+
ltags = tags(tn, edge)
603+
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
604+
tn[src(edge)] = X
605+
tn[dst(edge)] *= Y
606+
end
598607
return tn
599608
end
600609

601610
function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
602611
return _orthogonalize_edge(tn, edge; kwargs...)
603612
end
604613

614+
function ITensorMPS.orthogonalize(
615+
tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs...
616+
)
617+
return _orthogonalize_edges(tn, edges; kwargs...)
618+
end
619+
605620
function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
606621
return orthogonalize(tn, edgetype(tn)(edge); kwargs...)
607622
end
608623

609-
# Orthogonalize an ITensorNetwork towards a source vertex, treating
624+
function ITensorMPS.orthogonalize(
625+
tn::AbstractITensorNetwork, edges::Vector{Pair}; kwargs...
626+
)
627+
return orthogonalize(tn, edgetype(tn).(edges); kwargs...)
628+
end
629+
630+
# Orthogonalize an ITensorNetwork towards a region, treating
610631
# the network as a tree spanned by a spanning tree.
611632
# TODO: Rename `tree_orthogonalize`.
612-
function ITensorMPS.orthogonalize::AbstractITensorNetwork, source_vertex)
613-
spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, source_vertex), source_vertex)
614-
for e in spanning_tree_edges
615-
ψ = orthogonalize(ψ, e)
616-
end
617-
return ψ
633+
function ITensorMPS.orthogonalize::AbstractITensorNetwork, region::Vector)
634+
spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, first(region)), first(region))
635+
spanning_tree_edges = filter(
636+
e -> !(src(e) region && dst(e) region), spanning_tree_edges
637+
)
638+
return orthogonalize(ψ, spanning_tree_edges)
639+
end
640+
641+
function ITensorMPS.orthogonalize::AbstractITensorNetwork, region)
642+
return orthogonalize(ψ, [region])
618643
end
619644

620645
# TODO: decide whether to use graph mutating methods when resulting graph is unchanged?

src/solvers/alternating_update/region_update.jl

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,3 @@
1-
#ToDo: generalize beyond 2-site
2-
#ToDo: remove concept of orthogonality center for generality
3-
function current_ortho(sweep_plan, which_region_update)
4-
regions = first.(sweep_plan)
5-
region = regions[which_region_update]
6-
current_verts = support(region)
7-
if !isa(region, AbstractEdge) && length(region) == 1
8-
return only(current_verts)
9-
end
10-
# look forward
11-
other_regions = filter(
12-
x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end])
13-
)
14-
# find the first region that has overlapping support with current region
15-
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
16-
if isnothing(ind)
17-
# look backward
18-
other_regions = reverse(
19-
filter(
20-
x -> !(issetequal(x, current_verts)), support.(regions[1:(which_region_update - 1)])
21-
),
22-
)
23-
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
24-
end
25-
@assert !isnothing(ind)
26-
future_verts = union(support(other_regions[ind]))
27-
# return ortho_ceter as the vertex in current region that does not overlap with following one
28-
overlapping_vertex = intersect(current_verts, future_verts)
29-
nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex))
30-
return nonoverlapping_vertex
31-
end
32-
331
function region_update(
342
projected_operator,
353
state;
@@ -55,14 +23,13 @@ function region_update(
5523

5624
# ToDo: remove orthogonality center on vertex for generality
5725
# region carries same information
58-
ortho_vertex = current_ortho(sweep_plan, which_region_update)
5926
if !isnothing(transform_operator)
6027
projected_operator = transform_operator(
6128
state, projected_operator; outputlevel, transform_operator_kwargs...
6229
)
6330
end
6431
state, projected_operator, phi = extracter(
65-
state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs
32+
state, projected_operator, region; extracter_kwargs..., internal_kwargs
6633
)
6734
# create references, in case solver does (out-of-place) modify PH or state
6835
state! = Ref(state)
@@ -88,9 +55,7 @@ function region_update(
8855
# drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees...
8956
# so noiseterm is a solver
9057
#end
91-
state, spec = inserter(
92-
state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs
93-
)
58+
state, spec = inserter(state, phi, region; inserter_kwargs..., internal_kwargs)
9459
all_kwargs = (;
9560
which_region_update,
9661
sweep_plan,

src/solvers/extract/extract.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@
77
# insert_local_tensors takes that tensor and factorizes it back
88
# apart and puts it back into the network.
99
#
10-
function default_extracter(state, projected_operator, region, ortho; internal_kwargs)
11-
state = orthogonalize(state, ortho)
10+
function default_extracter(state, projected_operator, region; internal_kwargs)
1211
if isa(region, AbstractEdge)
13-
other_vertex = only(setdiff(support(region), [ortho]))
14-
left_inds = uniqueinds(state[ortho], state[other_vertex])
12+
vsrc, vdst = src(region), dst(region)
13+
state = orthogonalize(state, vsrc)
14+
left_inds = uniqueinds(state[vsrc], state[vdst])
1515
#ToDo: replace with call to factorize
1616
U, S, V = svd(
17-
state[ortho], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
17+
state[vsrc], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
1818
)
19-
state[ortho] = U
19+
state[vsrc] = U
2020
local_tensor = S * V
2121
else
22+
state = orthogonalize(state, region)
2223
local_tensor = prod(state[v] for v in region)
2324
end
2425
projected_operator = position(projected_operator, state, region)

src/solvers/insert/insert.jl

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
function default_inserter(
77
state::AbstractTTN,
88
phi::ITensor,
9-
region,
10-
ortho_vert;
9+
region;
1110
normalize=false,
1211
maxdim=nothing,
1312
mindim=nothing,
@@ -16,16 +15,14 @@ function default_inserter(
1615
)
1716
state = copy(state)
1817
spec = nothing
19-
other_vertex = setdiff(support(region), [ortho_vert])
20-
if !isempty(other_vertex)
21-
v = only(other_vertex)
22-
e = edgetype(state)(ortho_vert, v)
23-
indsTe = inds(state[ortho_vert])
18+
if length(region) == 2
19+
v = last(region)
20+
e = edgetype(state)(first(region), last(region))
21+
indsTe = inds(state[first(region)])
2422
L, phi, spec = factorize(phi, indsTe; tags=tags(state, e), maxdim, mindim, cutoff)
25-
state[ortho_vert] = L
26-
23+
state[first(region)] = L
2724
else
28-
v = ortho_vert
25+
v = only(region)
2926
end
3027
state[v] = phi
3128
state = set_ortho_region(state, [v])
@@ -44,8 +41,7 @@ function default_inserter(
4441
normalize=false,
4542
internal_kwargs,
4643
)
47-
v = only(setdiff(support(region), [ortho]))
48-
state[v] *= phi
49-
state = set_ortho_region(state, [v])
44+
state[dst(region)] *= phi
45+
state = set_ortho_region(state, [dst(region)])
5046
return state, nothing
5147
end

src/treetensornetworks/abstracttreetensornetwork.jl

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
using Graphs: has_vertex
22
using NamedGraphs.GraphsExtensions:
3-
GraphsExtensions, edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices
3+
GraphsExtensions,
4+
edge_path,
5+
leaf_vertices,
6+
post_order_dfs_edges,
7+
post_order_dfs_vertices,
8+
a_star
9+
using NamedGraphs: namedgraph_a_star
410
using IsApprox: IsApprox, Approx
511
using ITensors: @Algorithm_str, directsum, hasinds, permute, plev
612
using ITensorMPS: linkind, loginner, lognorm, orthogonalize
@@ -29,30 +35,16 @@ function set_ortho_region(tn::AbstractTTN, new_region)
2935
return error("Not implemented")
3036
end
3137

32-
#
33-
# Orthogonalization
34-
#
35-
36-
function ITensorMPS.orthogonalize(tn::AbstractTTN, ortho_center; kwargs...)
37-
if isone(length(ortho_region(tn))) && ortho_center == only(ortho_region(tn))
38-
return tn
39-
end
40-
# TODO: Rewrite this in a more general way.
41-
if isone(length(ortho_region(tn)))
42-
edge_list = edge_path(tn, only(ortho_region(tn)), ortho_center)
43-
else
44-
edge_list = post_order_dfs_edges(tn, ortho_center)
45-
end
46-
for e in edge_list
47-
tn = orthogonalize(tn, e)
38+
function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...)
39+
paths = [
40+
namedgraph_a_star(underlying_graph(ttn), rp, r) for r in region for
41+
rp in ortho_region(ttn)
42+
]
43+
path = unique(reduce(vcat, paths))
44+
if !isempty(path)
45+
ttn = typeof(ttn)(orthogonalize(ITensorNetwork(ttn), path; kwargs...))
4846
end
49-
return set_ortho_region(tn, typeof(ortho_region(tn))([ortho_center]))
50-
end
51-
52-
# For ambiguity error
53-
54-
function ITensorMPS.orthogonalize(tn::AbstractTTN, edge::AbstractEdge; kwargs...)
55-
return typeof(tn)(orthogonalize(ITensorNetwork(tn), edge; kwargs...))
47+
return set_ortho_region(ttn, region)
5648
end
5749

5850
#

0 commit comments

Comments
 (0)