Skip to content

Commit b0d6aa7

Browse files
authored
Current ortho fix (#208)
1 parent 33d3006 commit b0d6aa7

File tree

10 files changed

+103
-134
lines changed

10 files changed

+103
-134
lines changed

src/abstractitensornetwork.jl

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Graphs:
77
add_edge!,
88
add_vertex!,
99
bfs_tree,
10+
center,
1011
dst,
1112
edges,
1213
edgetype,
@@ -40,7 +41,7 @@ using ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds
4041
using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype
4142
using LinearAlgebra: LinearAlgebra, factorize
4243
using MacroTools: @capture
43-
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented
44+
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
4445
using NamedGraphs.GraphsExtensions:
4546
, directed_graph, incident_edges, rename_vertices, vertextype
4647
using NDTensors: NDTensors, dim
@@ -584,37 +585,49 @@ function LinearAlgebra.factorize(tn::AbstractITensorNetwork, edge::Pair; kwargs.
584585
end
585586

586587
# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
587-
function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
588+
function orthogonalize_walk(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
589+
return orthogonalize_walk(tn, [edge]; kwargs...)
590+
end
591+
592+
function orthogonalize_walk(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
593+
return orthogonalize_walk(tn, edgetype(tn)(edge); kwargs...)
594+
end
595+
596+
# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
597+
function orthogonalize_walk(
598+
tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs...
599+
)
588600
# tn = factorize(tn, edge; kwargs...)
589601
# # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))`
590602
# new_vertex = only(neighbors(tn, src(edge)) ∩ neighbors(tn, dst(edge)))
591603
# return contract(tn, new_vertex => dst(edge))
592604
tn = copy(tn)
593-
left_inds = uniqueinds(tn, edge)
594-
ltags = tags(tn, edge)
595-
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
596-
tn[src(edge)] = X
597-
tn[dst(edge)] *= Y
605+
for edge in edges
606+
left_inds = uniqueinds(tn, edge)
607+
ltags = tags(tn, edge)
608+
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
609+
tn[src(edge)] = X
610+
tn[dst(edge)] *= Y
611+
end
598612
return tn
599613
end
600614

601-
function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
602-
return _orthogonalize_edge(tn, edge; kwargs...)
615+
function orthogonalize_walk(tn::AbstractITensorNetwork, edges::Vector{<:Pair}; kwargs...)
616+
return orthogonalize_walk(tn, edgetype(tn).(edges); kwargs...)
603617
end
604618

605-
function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
606-
return orthogonalize(tn, edgetype(tn)(edge); kwargs...)
619+
# Orthogonalize an ITensorNetwork towards a region, treating
620+
# the network as a tree spanned by a spanning tree.
621+
function tree_orthogonalize::AbstractITensorNetwork, region::Vector)
622+
region_center =
623+
length(region) != 1 ? first(center(steiner_tree(ψ, region))) : only(region)
624+
path = post_order_dfs_edges(bfs_tree(ψ, region_center), region_center)
625+
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
626+
return orthogonalize_walk(ψ, path)
607627
end
608628

609-
# Orthogonalize an ITensorNetwork towards a source vertex, treating
610-
# the network as a tree spanned by a spanning tree.
611-
# TODO: Rename `tree_orthogonalize`.
612-
function ITensorMPS.orthogonalize::AbstractITensorNetwork, source_vertex)
613-
spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, source_vertex), source_vertex)
614-
for e in spanning_tree_edges
615-
ψ = orthogonalize(ψ, e)
616-
end
617-
return ψ
629+
function tree_orthogonalize::AbstractITensorNetwork, region)
630+
return tree_orthogonalize(ψ, [region])
618631
end
619632

620633
# TODO: decide whether to use graph mutating methods when resulting graph is unchanged?

src/apply.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ function ITensors.apply(
200200
v⃗ = neighbor_vertices(ψ, o)
201201
if length(v⃗) == 1
202202
if ortho
203-
ψ = orthogonalize(ψ, v⃗[1])
203+
ψ = tree_orthogonalize(ψ, v⃗[1])
204204
end
205205
oψᵥ = apply(o, ψ[v⃗[1]])
206206
if normalize
@@ -215,7 +215,7 @@ function ITensors.apply(
215215
error("Vertices where the gates are being applied must be neighbors for now.")
216216
end
217217
if ortho
218-
ψ = orthogonalize(ψ, v⃗[1])
218+
ψ = tree_orthogonalize(ψ, v⃗[1])
219219
end
220220
if variational_optimization_only || !is_product_env
221221
ψᵥ₁, ψᵥ₂ = full_update_bp(

src/solvers/alternating_update/region_update.jl

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,3 @@
1-
#ToDo: generalize beyond 2-site
2-
#ToDo: remove concept of orthogonality center for generality
3-
function current_ortho(sweep_plan, which_region_update)
4-
regions = first.(sweep_plan)
5-
region = regions[which_region_update]
6-
current_verts = support(region)
7-
if !isa(region, AbstractEdge) && length(region) == 1
8-
return only(current_verts)
9-
end
10-
if which_region_update == length(regions)
11-
# look back by one should be sufficient, but may be brittle?
12-
overlapping_vertex = only(
13-
intersect(current_verts, support(regions[which_region_update - 1]))
14-
)
15-
return overlapping_vertex
16-
else
17-
# look forward
18-
other_regions = filter(
19-
x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end])
20-
)
21-
# find the first region that has overlapping support with current region
22-
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
23-
if isnothing(ind)
24-
# look backward
25-
other_regions = reverse(
26-
filter(
27-
x -> !(issetequal(x, current_verts)),
28-
support.(regions[1:(which_region_update - 1)]),
29-
),
30-
)
31-
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
32-
end
33-
@assert !isnothing(ind)
34-
future_verts = union(support(other_regions[ind]))
35-
# return ortho_ceter as the vertex in current region that does not overlap with following one
36-
overlapping_vertex = intersect(current_verts, future_verts)
37-
nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex))
38-
return nonoverlapping_vertex
39-
end
40-
end
41-
421
function region_update(
432
projected_operator,
443
state;
@@ -64,14 +23,13 @@ function region_update(
6423

6524
# ToDo: remove orthogonality center on vertex for generality
6625
# region carries same information
67-
ortho_vertex = current_ortho(sweep_plan, which_region_update)
6826
if !isnothing(transform_operator)
6927
projected_operator = transform_operator(
7028
state, projected_operator; outputlevel, transform_operator_kwargs...
7129
)
7230
end
7331
state, projected_operator, phi = extracter(
74-
state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs
32+
state, projected_operator, region; extracter_kwargs..., internal_kwargs
7533
)
7634
# create references, in case solver does (out-of-place) modify PH or state
7735
state! = Ref(state)
@@ -97,9 +55,8 @@ function region_update(
9755
# drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees...
9856
# so noiseterm is a solver
9957
#end
100-
state, spec = inserter(
101-
state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs
102-
)
58+
#if isa(region, AbstractEdge) &&
59+
state, spec = inserter(state, phi, region; inserter_kwargs..., internal_kwargs)
10360
all_kwargs = (;
10461
which_region_update,
10562
sweep_plan,

src/solvers/extract/extract.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,20 @@
77
# insert_local_tensors takes that tensor and factorizes it back
88
# apart and puts it back into the network.
99
#
10-
function default_extracter(state, projected_operator, region, ortho; internal_kwargs)
11-
state = orthogonalize(state, ortho)
10+
11+
function default_extracter(state, projected_operator, region; internal_kwargs)
1212
if isa(region, AbstractEdge)
13-
other_vertex = only(setdiff(support(region), [ortho]))
14-
left_inds = uniqueinds(state[ortho], state[other_vertex])
15-
#ToDo: replace with call to factorize
13+
# TODO: add functionality for orthogonalizing onto a bond so that can be called instead
14+
vsrc, vdst = src(region), dst(region)
15+
state = orthogonalize(state, vsrc)
16+
left_inds = uniqueinds(state[vsrc], state[vdst])
1617
U, S, V = svd(
17-
state[ortho], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
18+
state[vsrc], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
1819
)
19-
state[ortho] = U
20+
state[vsrc] = U
2021
local_tensor = S * V
2122
else
23+
state = orthogonalize(state, region)
2224
local_tensor = prod(state[v] for v in region)
2325
end
2426
projected_operator = position(projected_operator, state, region)

src/solvers/insert/insert.jl

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
function default_inserter(
77
state::AbstractTTN,
88
phi::ITensor,
9-
region,
10-
ortho_vert;
9+
region;
1110
normalize=false,
1211
maxdim=nothing,
1312
mindim=nothing,
@@ -16,16 +15,14 @@ function default_inserter(
1615
)
1716
state = copy(state)
1817
spec = nothing
19-
other_vertex = setdiff(support(region), [ortho_vert])
20-
if !isempty(other_vertex)
21-
v = only(other_vertex)
22-
e = edgetype(state)(ortho_vert, v)
23-
indsTe = inds(state[ortho_vert])
18+
if length(region) == 2
19+
v = last(region)
20+
e = edgetype(state)(first(region), last(region))
21+
indsTe = inds(state[first(region)])
2422
L, phi, spec = factorize(phi, indsTe; tags=tags(state, e), maxdim, mindim, cutoff)
25-
state[ortho_vert] = L
26-
23+
state[first(region)] = L
2724
else
28-
v = ortho_vert
25+
v = only(region)
2926
end
3027
state[v] = phi
3128
state = set_ortho_region(state, [v])
@@ -36,16 +33,14 @@ end
3633
function default_inserter(
3734
state::AbstractTTN,
3835
phi::ITensor,
39-
region::NamedEdge,
40-
ortho;
36+
region::NamedEdge;
4137
cutoff=nothing,
4238
maxdim=nothing,
4339
mindim=nothing,
4440
normalize=false,
4541
internal_kwargs,
4642
)
47-
v = only(setdiff(support(region), [ortho]))
48-
state[v] *= phi
49-
state = set_ortho_region(state, [v])
43+
state[dst(region)] *= phi
44+
state = set_ortho_region(state, [dst(region)])
5045
return state, nothing
5146
end

src/solvers/sweep_plans/sweep_plans.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ end
1313

1414
support(r) = r
1515

16-
function reverse_region(edges, which_edge; nsites=1, region_kwargs=(;))
16+
function reverse_region(edges, which_edge; reverse_edge=false, nsites=1, region_kwargs=(;))
1717
current_edge = edges[which_edge]
1818
if nsites == 1
19-
return [(current_edge, region_kwargs)]
19+
!reverse_edge && return [(current_edge, region_kwargs)]
20+
reverse_edge && return [(reverse(current_edge), region_kwargs)]
2021
elseif nsites == 2
2122
if last(edges) == current_edge
2223
return ()
@@ -62,25 +63,24 @@ function forward_sweep(
6263
dir::Base.ForwardOrdering,
6364
graph::AbstractGraph;
6465
root_vertex=GraphsExtensions.default_root_vertex(graph),
66+
reverse_edges=false,
6567
region_kwargs,
6668
reverse_kwargs=region_kwargs,
6769
reverse_step=false,
6870
kwargs...,
6971
)
7072
edges = post_order_dfs_edges(graph, root_vertex)
71-
regions = collect(
72-
flatten(map(i -> forward_region(edges, i; region_kwargs, kwargs...), eachindex(edges)))
73-
)
74-
73+
regions = map(eachindex(edges)) do i
74+
forward_region(edges, i; region_kwargs, kwargs...)
75+
end
76+
regions = collect(flatten(regions))
7577
if reverse_step
76-
reverse_regions = collect(
77-
flatten(
78-
map(
79-
i -> reverse_region(edges, i; region_kwargs=reverse_kwargs, kwargs...),
80-
eachindex(edges),
81-
),
82-
),
83-
)
78+
reverse_regions = map(eachindex(edges)) do i
79+
reverse_region(
80+
edges, i; reverse_edge=reverse_edges, region_kwargs=reverse_kwargs, kwargs...
81+
)
82+
end
83+
reverse_regions = collect(flatten(reverse_regions))
8484
_check_reverse_sweeps(regions, reverse_regions, graph; kwargs...)
8585
regions = interleave(regions, reverse_regions)
8686
end
@@ -90,7 +90,7 @@ end
9090

9191
#ToDo: is there a better name for this? unidirectional_sweep? traversal?
9292
function forward_sweep(dir::Base.ReverseOrdering, args...; kwargs...)
93-
return reverse(forward_sweep(Base.Forward, args...; kwargs...))
93+
return reverse(forward_sweep(Base.Forward, args...; reverse_edges=true, kwargs...))
9494
end
9595

9696
function default_sweep_plans(

src/tebd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function tebd(
2323
ψ = apply(u⃗, ψ; cutoff, maxdim, normalize=true, ortho, kwargs...)
2424
if ortho
2525
for v in vertices(ψ)
26-
ψ = orthogonalize(ψ, v)
26+
ψ = tree_orthogonalize(ψ, v)
2727
end
2828
end
2929
end

src/treetensornetworks/abstracttreetensornetwork.jl

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
using Graphs: has_vertex
22
using NamedGraphs.GraphsExtensions:
3-
GraphsExtensions, edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices
3+
GraphsExtensions,
4+
edge_path,
5+
leaf_vertices,
6+
post_order_dfs_edges,
7+
post_order_dfs_vertices,
8+
a_star
9+
using NamedGraphs: namedgraph_a_star, steiner_tree
410
using IsApprox: IsApprox, Approx
511
using ITensors: ITensors, @Algorithm_str, directsum, hasinds, permute, plev
612
using ITensorMPS: ITensorMPS, linkind, loginner, lognorm, orthogonalize
@@ -29,30 +35,23 @@ function set_ortho_region(tn::AbstractTTN, new_region)
2935
return error("Not implemented")
3036
end
3137

32-
#
33-
# Orthogonalization
34-
#
35-
36-
function ITensorMPS.orthogonalize(tn::AbstractTTN, ortho_center; kwargs...)
37-
if isone(length(ortho_region(tn))) && ortho_center == only(ortho_region(tn))
38-
return tn
39-
end
40-
# TODO: Rewrite this in a more general way.
41-
if isone(length(ortho_region(tn)))
42-
edge_list = edge_path(tn, only(ortho_region(tn)), ortho_center)
43-
else
44-
edge_list = post_order_dfs_edges(tn, ortho_center)
45-
end
46-
for e in edge_list
47-
tn = orthogonalize(tn, e)
38+
function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...)
39+
issetequal(region, ortho_region(ttn)) && return ttn
40+
st = steiner_tree(ttn, union(region, ortho_region(ttn)))
41+
path = post_order_dfs_edges(st, first(region))
42+
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
43+
if !isempty(path)
44+
ttn = typeof(ttn)(orthogonalize_walk(ITensorNetwork(ttn), path; kwargs...))
4845
end
49-
return set_ortho_region(tn, typeof(ortho_region(tn))([ortho_center]))
46+
return set_ortho_region(ttn, region)
5047
end
5148

52-
# For ambiguity error
49+
function ITensorMPS.orthogonalize(ttn::AbstractTTN, region; kwargs...)
50+
return orthogonalize(ttn, [region]; kwargs...)
51+
end
5352

54-
function ITensorMPS.orthogonalize(tn::AbstractTTN, edge::AbstractEdge; kwargs...)
55-
return typeof(tn)(orthogonalize(ITensorNetwork(tn), edge; kwargs...))
53+
function tree_orthogonalize(ttn::AbstractTTN, args...; kwargs...)
54+
return orthogonalize(ttn, args...; kwargs...)
5655
end
5756

5857
#

0 commit comments

Comments
 (0)