Skip to content

Commit 73e9e1e

Browse files
committed
Merge remote-tracking branch 'origin/main' into normalize!
2 parents 2cb7f85 + 70a3f7e commit 73e9e1e

File tree

22 files changed

+281
-252
lines changed

22 files changed

+281
-252
lines changed

Project.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.11.15"
4+
version = "0.11.24"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -62,20 +62,20 @@ DocStringExtensions = "0.9"
6262
EinExprs = "0.6.4"
6363
Graphs = "1.8"
6464
GraphsFlows = "0.1.1"
65-
ITensorMPS = "0.2.2"
66-
ITensors = "0.6.8"
67-
IsApprox = "0.1"
65+
ITensorMPS = "0.3"
66+
ITensors = "0.7"
67+
IsApprox = "0.1, 1, 2"
6868
IterTools = "1.4.0"
69-
KrylovKit = "0.6, 0.7"
69+
KrylovKit = "0.6, 0.7, 0.8"
7070
MacroTools = "0.5"
7171
NDTensors = "0.3"
7272
NamedGraphs = "0.6.0"
73-
OMEinsumContractionOrders = "0.8.3"
73+
OMEinsumContractionOrders = "0.8.3, 0.9"
7474
Observers = "0.2.4"
7575
PackageExtensionCompat = "1"
7676
SerializedElementArrays = "0.1"
7777
SimpleTraits = "0.9"
78-
SparseArrayKit = "0.3"
78+
SparseArrayKit = "0.3, 0.4"
7979
SplitApplyCombine = "1.2"
8080
StaticArrays = "1.5.12"
8181
StructWalk = "0.2"

README.md

Lines changed: 72 additions & 72 deletions
Large diffs are not rendered by default.

examples/test.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using ITensorNetworks: IndsNetwork, siteinds, ttn
2+
using ITensorNetworks.ModelHamiltonians: ising
3+
using ITensors: Index, OpSum, terms, sites
4+
using NamedGraphs.NamedGraphGenerators: named_grid
5+
using NamedGraphs.GraphsExtensions: rem_vertex
6+
7+
function filter_terms(H, verts)
8+
H_new = OpSum()
9+
for term in terms(H)
10+
if isempty(filter(v -> v verts, sites(term)))
11+
H_new += term
12+
end
13+
end
14+
return H_new
15+
end
16+
17+
g = named_grid((8,1))
18+
s = siteinds("S=1/2", g)
19+
H = ising(s)
20+
H_mod = filter_terms(H, [(4,1)])
21+
ttno = ttn(H_mod, s)

src/abstractitensornetwork.jl

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Graphs:
77
add_edge!,
88
add_vertex!,
99
bfs_tree,
10+
center,
1011
dst,
1112
edges,
1213
edgetype,
@@ -18,6 +19,7 @@ using Graphs:
1819
using ITensors:
1920
ITensors,
2021
ITensor,
22+
@Algorithm_str,
2123
addtags,
2224
combiner,
2325
commoninds,
@@ -40,10 +42,10 @@ using ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds
4042
using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype
4143
using LinearAlgebra: LinearAlgebra, factorize
4244
using MacroTools: @capture
43-
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented
45+
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
4446
using NamedGraphs.GraphsExtensions:
4547
, directed_graph, incident_edges, rename_vertices, vertextype
46-
using NDTensors: NDTensors, dim
48+
using NDTensors: NDTensors, dim, Algorithm
4749
using SplitApplyCombine: flatten
4850

4951
abstract type AbstractITensorNetwork{V} <: AbstractDataGraph{V,ITensor,ITensor} end
@@ -584,7 +586,9 @@ function LinearAlgebra.factorize(tn::AbstractITensorNetwork, edge::Pair; kwargs.
584586
end
585587

586588
# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
587-
function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
589+
function gauge_edge(
590+
alg::Algorithm"orthogonalize", tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...
591+
)
588592
# tn = factorize(tn, edge; kwargs...)
589593
# # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))`
590594
# new_vertex = only(neighbors(tn, src(edge)) ∩ neighbors(tn, dst(edge)))
@@ -598,23 +602,43 @@ function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwa
598602
return tn
599603
end
600604

601-
function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
602-
return _orthogonalize_edge(tn, edge; kwargs...)
605+
# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
606+
function gauge_walk(
607+
alg::Algorithm, tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs...
608+
)
609+
tn = copy(tn)
610+
for edge in edges
611+
tn = gauge_edge(alg, tn, edge; kwargs...)
612+
end
613+
return tn
614+
end
615+
616+
function gauge_walk(alg::Algorithm, tn::AbstractITensorNetwork, edge::Pair; kwargs...)
617+
return gauge_edge(alg::Algorithm, tn, edgetype(tn)(edge); kwargs...)
603618
end
604619

605-
function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
606-
return orthogonalize(tn, edgetype(tn)(edge); kwargs...)
620+
function gauge_walk(
621+
alg::Algorithm, tn::AbstractITensorNetwork, edges::Vector{<:Pair}; kwargs...
622+
)
623+
return gauge_walk(alg, tn, edgetype(tn).(edges); kwargs...)
607624
end
608625

609-
# Orthogonalize an ITensorNetwork towards a source vertex, treating
626+
# Gauge a ITensorNetwork towards a region, treating
610627
# the network as a tree spanned by a spanning tree.
611-
# TODO: Rename `tree_orthogonalize`.
612-
function ITensorMPS.orthogonalize::AbstractITensorNetwork, source_vertex)
613-
spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, source_vertex), source_vertex)
614-
for e in spanning_tree_edges
615-
ψ = orthogonalize(ψ, e)
616-
end
617-
return ψ
628+
function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region::Vector)
629+
region_center =
630+
length(region) != 1 ? first(center(steiner_tree(ψ, region))) : only(region)
631+
path = post_order_dfs_edges(bfs_tree(ψ, region_center), region_center)
632+
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
633+
return gauge_walk(alg, ψ, path)
634+
end
635+
636+
function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region)
637+
return tree_gauge(alg, ψ, [region])
638+
end
639+
640+
function tree_orthogonalize::AbstractITensorNetwork, region; kwargs...)
641+
return tree_gauge(Algorithm("orthogonalize"), ψ, region; kwargs...)
618642
end
619643

620644
# TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
@@ -759,7 +783,7 @@ end
759783
# Link dimensions
760784
#
761785

762-
function ITensors.maxlinkdim(tn::AbstractITensorNetwork)
786+
function ITensorMPS.maxlinkdim(tn::AbstractITensorNetwork)
763787
md = 1
764788
for e in edges(tn)
765789
md = max(md, linkdim(tn, e))

src/apply.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ function ITensors.apply(
200200
v⃗ = neighbor_vertices(ψ, o)
201201
if length(v⃗) == 1
202202
if ortho
203-
ψ = orthogonalize(ψ, v⃗[1])
203+
ψ = tree_orthogonalize(ψ, v⃗[1])
204204
end
205205
oψᵥ = apply(o, ψ[v⃗[1]])
206206
if normalize
@@ -215,7 +215,7 @@ function ITensors.apply(
215215
error("Vertices where the gates are being applied must be neighbors for now.")
216216
end
217217
if ortho
218-
ψ = orthogonalize(ψ, v⃗[1])
218+
ψ = tree_orthogonalize(ψ, v⃗[1])
219219
end
220220
if variational_optimization_only || !is_product_env
221221
ψᵥ₁, ψᵥ₂ = full_update_bp(

src/caches/beliefpropagationcache.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ using NDTensors: NDTensors
1616

1717
default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e]
1818
default_messages(ptn::PartitionedGraph) = Dictionary()
19-
default_message_norm(m::ITensor) = norm(m)
2019
function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
2120
sequence = optimal_contraction_sequence(contract_list)
2221
updated_messages = contract(contract_list; sequence, kwargs...)
@@ -107,7 +106,7 @@ end
107106

108107
function message(bp_cache::BeliefPropagationCache, edge::PartitionEdge)
109108
mts = messages(bp_cache)
110-
return get(mts, edge, default_message(bp_cache, edge))
109+
return get(() -> default_message(bp_cache, edge), mts, edge)
111110
end
112111
function messages(bp_cache::BeliefPropagationCache, edges; kwargs...)
113112
return map(edge -> message(bp_cache, edge; kwargs...), edges)
@@ -153,24 +152,16 @@ end
153152
function environment(bp_cache::BeliefPropagationCache, verts::Vector)
154153
partition_verts = partitionvertices(bp_cache, verts)
155154
messages = environment(bp_cache, partition_verts)
156-
central_tensors = ITensor[
157-
tensornetwork(bp_cache)[v] for v in setdiff(vertices(bp_cache, partition_verts), verts)
158-
]
155+
central_tensors = factors(bp_cache, setdiff(vertices(bp_cache, partition_verts), verts))
159156
return vcat(messages, central_tensors)
160157
end
161158

162-
function factors(bp_cache::BeliefPropagationCache, vertices)
163-
tn = tensornetwork(bp_cache)
164-
return map(vertex -> tn[vertex], vertices)
165-
end
166-
167-
function factor(bp_cache::BeliefPropagationCache, vertex)
168-
return only(factors(bp_cache, [vertex]))
159+
function factors(bp_cache::BeliefPropagationCache, verts::Vector)
160+
return ITensor[tensornetwork(bp_cache)[v] for v in verts]
169161
end
170162

171163
function factor(bp_cache::BeliefPropagationCache, vertex::PartitionVertex)
172-
ptn = partitioned_tensornetwork(bp_cache)
173-
return collect(eachtensor(subgraph(ptn, vertex)))
164+
return factors(bp_cache, vertices(bp_cache, vertex))
174165
end
175166

176167
"""

src/inner.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using ITensors: inner, scalar, loginner
1+
using ITensors: inner, scalar
2+
using ITensorMPS: ITensorMPS, loginner
23
using LinearAlgebra: norm, norm_sqr
34

45
default_contract_alg(tns::Tuple) = "bp"
@@ -53,7 +54,7 @@ function ITensors.inner(
5354
return scalar(tn; sequence)
5455
end
5556

56-
function ITensors.loginner(
57+
function ITensorMPS.loginner(
5758
ϕ::AbstractITensorNetwork,
5859
ψ::AbstractITensorNetwork;
5960
alg=default_contract_alg((ϕ, ψ)),
@@ -62,7 +63,7 @@ function ITensors.loginner(
6263
return loginner(Algorithm(alg), ϕ, ψ; kwargs...)
6364
end
6465

65-
function ITensors.loginner(
66+
function ITensorMPS.loginner(
6667
ϕ::AbstractITensorNetwork,
6768
A::AbstractITensorNetwork,
6869
ψ::AbstractITensorNetwork;
@@ -72,13 +73,13 @@ function ITensors.loginner(
7273
return loginner(Algorithm(alg), ϕ, A, ψ; kwargs...)
7374
end
7475

75-
function ITensors.loginner(
76+
function ITensorMPS.loginner(
7677
alg::Algorithm"exact", ϕ::AbstractITensorNetwork, ψ::AbstractITensorNetwork; kwargs...
7778
)
7879
return log(inner(alg, ϕ, ψ); kwargs...)
7980
end
8081

81-
function ITensors.loginner(
82+
function ITensorMPS.loginner(
8283
alg::Algorithm"exact",
8384
ϕ::AbstractITensorNetwork,
8485
A::AbstractITensorNetwork,
@@ -88,7 +89,7 @@ function ITensors.loginner(
8889
return log(inner(alg, ϕ, A, ψ); kwargs...)
8990
end
9091

91-
function ITensors.loginner(
92+
function ITensorMPS.loginner(
9293
alg::Algorithm"bp",
9394
ϕ::AbstractITensorNetwork,
9495
ψ::AbstractITensorNetwork;
@@ -99,7 +100,7 @@ function ITensors.loginner(
99100
return logscalar(alg, tn; kwargs...)
100101
end
101102

102-
function ITensors.loginner(
103+
function ITensorMPS.loginner(
103104
alg::Algorithm"bp",
104105
ϕ::AbstractITensorNetwork,
105106
A::AbstractITensorNetwork,

src/solvers/alternating_update/alternating_update.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ function alternating_update(
99
nsites, # define default for each level of solver implementation
1010
updater, # this specifies the update performed locally
1111
outputlevel=default_outputlevel(),
12-
region_printer=nothing,
13-
sweep_printer=nothing,
12+
region_printer=default_region_printer,
13+
sweep_printer=default_sweep_printer,
1414
(sweep_observer!)=nothing,
1515
(region_observer!)=nothing,
1616
root_vertex=GraphsExtensions.default_root_vertex(init_state),
@@ -59,7 +59,7 @@ function alternating_update(
5959
(sweep_observer!)=nothing,
6060
sweep_printer=default_sweep_printer,#?
6161
(region_observer!)=nothing,
62-
region_printer=nothing,
62+
region_printer=default_region_printer,
6363
)
6464
state = copy(init_state)
6565
@assert !isnothing(sweep_plans)

src/solvers/alternating_update/region_update.jl

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,3 @@
1-
#ToDo: generalize beyond 2-site
2-
#ToDo: remove concept of orthogonality center for generality
3-
function current_ortho(sweep_plan, which_region_update)
4-
regions = first.(sweep_plan)
5-
region = regions[which_region_update]
6-
current_verts = support(region)
7-
if !isa(region, AbstractEdge) && length(region) == 1
8-
return only(current_verts)
9-
end
10-
if which_region_update == length(regions)
11-
# look back by one should be sufficient, but may be brittle?
12-
overlapping_vertex = only(
13-
intersect(current_verts, support(regions[which_region_update - 1]))
14-
)
15-
return overlapping_vertex
16-
else
17-
# look forward
18-
other_regions = filter(
19-
x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end])
20-
)
21-
# find the first region that has overlapping support with current region
22-
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
23-
if isnothing(ind)
24-
# look backward
25-
other_regions = reverse(
26-
filter(
27-
x -> !(issetequal(x, current_verts)),
28-
support.(regions[1:(which_region_update - 1)]),
29-
),
30-
)
31-
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
32-
end
33-
@assert !isnothing(ind)
34-
future_verts = union(support(other_regions[ind]))
35-
# return ortho_ceter as the vertex in current region that does not overlap with following one
36-
overlapping_vertex = intersect(current_verts, future_verts)
37-
nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex))
38-
return nonoverlapping_vertex
39-
end
40-
end
41-
421
function region_update(
432
projected_operator,
443
state;
@@ -64,14 +23,13 @@ function region_update(
6423

6524
# ToDo: remove orthogonality center on vertex for generality
6625
# region carries same information
67-
ortho_vertex = current_ortho(sweep_plan, which_region_update)
6826
if !isnothing(transform_operator)
6927
projected_operator = transform_operator(
7028
state, projected_operator; outputlevel, transform_operator_kwargs...
7129
)
7230
end
7331
state, projected_operator, phi = extracter(
74-
state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs
32+
state, projected_operator, region; extracter_kwargs..., internal_kwargs
7533
)
7634
# create references, in case solver does (out-of-place) modify PH or state
7735
state! = Ref(state)
@@ -97,9 +55,8 @@ function region_update(
9755
# drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees...
9856
# so noiseterm is a solver
9957
#end
100-
state, spec = inserter(
101-
state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs
102-
)
58+
#if isa(region, AbstractEdge) &&
59+
state, spec = inserter(state, phi, region; inserter_kwargs..., internal_kwargs)
10360
all_kwargs = (;
10461
which_region_update,
10562
sweep_plan,

src/solvers/contract.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Graphs: nv, vertices
2-
using ITensors: ITensors, linkinds, sim
2+
using ITensors: ITensors, sim
3+
using ITensorMPS: linkinds
34
using ITensors.NDTensors: Algorithm, @Algorithm_str, contract
45
using NamedGraphs: vertextype
56

0 commit comments

Comments
 (0)