Skip to content

Commit 2452869

Browse files
jack-dunhamJack Dunhammtfishman
authored
Changes to the iterator interface (#255)
Co-authored-by: Jack Dunham <[email protected]> Co-authored-by: Matt Fishman <[email protected]>
1 parent 5628922 commit 2452869

30 files changed

+733
-2300
lines changed

src/solvers/abstract_problem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11

22
abstract type AbstractProblem end
33

4-
set_truncation_info(P::AbstractProblem, args...; kws...) = P
4+
set_truncation_info!(P::AbstractProblem, args...; kws...) = P

src/solvers/adapters.jl

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,46 @@
1+
"""
2+
struct PauseAfterIncrement{S<:AbstractNetworkIterator}
13
2-
#
3-
# TupleRegionIterator
4-
#
5-
# Adapts outputs to be (region, region_kwargs) tuples
6-
#
7-
# More generic design? maybe just assuming RegionIterator
8-
# or its outputs implement some interface function that
9-
# generates each tuple?
10-
#
11-
12-
mutable struct TupleRegionIterator{RegionIter}
13-
region_iterator::RegionIter
4+
Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
5+
process. This allows one to manually call a custom `compute!` or insert their own code it in
6+
the loop body in place of `compute!`.
7+
"""
8+
struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator
9+
parent::S
1410
end
1511

16-
region_iterator(T::TupleRegionIterator) = T.region_iterator
12+
islaststep(adapter::IncrementOnly) = islaststep(adapter.parent)
13+
state(adapter::IncrementOnly) = state(adapter.parent)
14+
increment!(adapter::IncrementOnly) = increment!(adapter.parent)
15+
compute!(adapter::IncrementOnly) = adapter
1716

18-
function Base.iterate(T::TupleRegionIterator, which=1)
19-
state = iterate(region_iterator(T), which)
20-
isnothing(state) && return nothing
21-
(current_region, region_kwargs) = current_region_plan(region_iterator(T))
22-
return (current_region, region_kwargs), last(state)
23-
end
17+
IncrementOnly(adapter::IncrementOnly) = adapter
2418

2519
"""
26-
region_tuples(R::RegionIterator)
20+
struct EachRegion{SweepIterator} <: AbstractNetworkIterator
2721
28-
The `region_tuples` adapter converts a RegionIterator into an
29-
iterator which outputs a tuple of the form (current_region, current_region_kwargs)
30-
at each step.
22+
Adapter that flattens each region iterator in the parent sweep iterator into a single
23+
iterator.
3124
"""
32-
region_tuples(R::RegionIterator) = TupleRegionIterator(R)
25+
struct EachRegion{SI<:SweepIterator} <: AbstractNetworkIterator
26+
parent::SI
27+
end
28+
29+
# In keeping with Julia convention.
30+
eachregion(iter::SweepIterator) = EachRegion(iter)
31+
32+
# Essential definitions
33+
function islaststep(adapter::EachRegion)
34+
region_iter = region_iterator(adapter.parent)
35+
return islaststep(adapter.parent) && islaststep(region_iter)
36+
end
37+
function increment!(adapter::EachRegion)
38+
region_iter = region_iterator(adapter.parent)
39+
islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter)
40+
return adapter
41+
end
42+
function compute!(adapter::EachRegion)
43+
region_iter = region_iterator(adapter.parent)
44+
compute!(region_iter)
45+
return adapter
46+
end

src/solvers/applyexp.jl

Lines changed: 51 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Printf: @printf
2-
using Accessors: @set
32

43
@kwdef mutable struct ApplyExpProblem{State} <: AbstractProblem
54
operator
@@ -11,66 +10,69 @@ operator(A::ApplyExpProblem) = A.operator
1110
state(A::ApplyExpProblem) = A.state
1211
current_exponent(A::ApplyExpProblem) = A.current_exponent
1312
function current_time(A::ApplyExpProblem)
14-
t = im*A.current_exponent
13+
t = im * A.current_exponent
1514
return iszero(imag(t)) ? real(t) : t
1615
end
1716

18-
set_operator(A::ApplyExpProblem, operator) = (@set A.operator = operator)
19-
set_state(A::ApplyExpProblem, state) = (@set A.state = state)
20-
set_current_exponent(A::ApplyExpProblem, exponent) = (@set A.current_exponent = exponent)
21-
22-
function region_plan(A::ApplyExpProblem; nsites, time_step, sweep_kwargs...)
23-
return applyexp_regions(state(A), time_step; nsites, sweep_kwargs...)
17+
# Rename region_plan
18+
function region_plan(A::ApplyExpProblem; nsites, exponent_step, sweep_kwargs...)
19+
# The `exponent_step` kwarg for the `update!` function needs some pre-processing.
20+
return applyexp_regions(state(A), exponent_step; nsites, sweep_kwargs...)
2421
end
2522

26-
function update(
27-
prob::ApplyExpProblem,
28-
local_state,
29-
region_iterator;
23+
function update!(
24+
region_iter::RegionIterator{<:ApplyExpProblem},
25+
local_state;
3026
nsites,
3127
exponent_step,
3228
solver=runge_kutta_solver,
33-
outputlevel,
34-
kws...,
3529
)
36-
iszero(abs(exponent_step)) && return prob, local_state
30+
prob = problem(region_iter)
31+
32+
if iszero(abs(exponent_step))
33+
return region_iter, local_state
34+
end
3735

38-
local_state, info = solver(
39-
x->optimal_map(operator(prob), x), exponent_step, local_state; kws...
36+
solver_kwargs = region_kwargs(solver, region_iter)
37+
38+
local_state, _ = solver(
39+
x -> optimal_map(operator(prob), x), exponent_step, local_state; solver_kwargs...
4040
)
41-
if nsites==1
42-
curr_reg = current_region(region_iterator)
43-
next_reg = next_region(region_iterator)
41+
if nsites == 1
42+
curr_reg = current_region(region_iter)
43+
next_reg = next_region(region_iter)
4444
if !isnothing(next_reg) && next_reg != curr_reg
4545
next_edge = first(edge_sequence_between_regions(state(prob), curr_reg, next_reg))
4646
v1, v2 = src(next_edge), dst(next_edge)
4747
psi = copy(state(prob))
4848
psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2]))
49-
shifted_operator = position(operator(prob), psi, NamedEdge(v1=>v2))
50-
R_t, _ = solver(x->optimal_map(shifted_operator, x), -exponent_step, R; kws...)
51-
local_state = psi[v1]*R_t
49+
shifted_operator = position(operator(prob), psi, NamedEdge(v1 => v2))
50+
R_t, _ = solver(
51+
x -> optimal_map(shifted_operator, x), -exponent_step, R; solver_kwargs...
52+
)
53+
local_state = psi[v1] * R_t
5254
end
5355
end
5456

55-
prob = set_current_exponent(prob, current_exponent(prob)+exponent_step)
57+
prob.current_exponent += exponent_step
5658

57-
return prob, local_state
59+
return region_iter, local_state
5860
end
5961

60-
function sweep_callback(
61-
problem::ApplyExpProblem;
62+
function default_sweep_callback(
63+
sweep_iterator::SweepIterator{<:ApplyExpProblem};
6264
exponent_description="exponent",
63-
outputlevel,
64-
sweep,
65-
nsweeps,
65+
outputlevel=0,
6666
process_time=identity,
67-
kws...,
6867
)
6968
if outputlevel >= 1
69+
the_problem = problem(sweep_iterator)
7070
@printf(
71-
" Current %s = %s, ", exponent_description, process_time(current_exponent(problem))
71+
" Current %s = %s, ",
72+
exponent_description,
73+
process_time(current_exponent(the_problem))
7274
)
73-
@printf("maxlinkdim=%d", maxlinkdim(state(problem)))
75+
@printf("maxlinkdim=%d", maxlinkdim(state(the_problem)))
7476
println()
7577
flush(stdout)
7678
end
@@ -79,19 +81,20 @@ end
7981
function applyexp(
8082
init_prob::AbstractProblem,
8183
exponents;
82-
extract_kwargs=(;),
83-
update_kwargs=(;),
84-
insert_kwargs=(;),
85-
outputlevel=0,
86-
nsites=1,
84+
sweep_callback=default_sweep_callback,
8785
order=4,
88-
kws...,
86+
nsites=2,
87+
sweep_kwargs...,
8988
)
9089
exponent_steps = diff([zero(eltype(exponents)); exponents])
91-
sweep_kws = (; outputlevel, extract_kwargs, insert_kwargs, nsites, order, update_kwargs)
92-
kws_array = [(; sweep_kws..., time_step=t) for t in exponent_steps]
93-
sweep_iter = sweep_iterator(init_prob, kws_array)
94-
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)
90+
91+
kws_array = [
92+
(; order, nsites, sweep_kwargs..., exponent_step) for exponent_step in exponent_steps
93+
]
94+
sweep_iter = SweepIterator(init_prob, kws_array)
95+
96+
converged_prob = problem(sweep_solve!(sweep_callback, sweep_iter))
97+
9598
return state(converged_prob)
9699
end
97100

@@ -111,11 +114,10 @@ function time_evolve(
111114
time_points,
112115
init_state;
113116
process_time=process_real_times,
114-
sweep_callback=(
115-
a...; k...
116-
)->sweep_callback(a...; exponent_description="time", process_time, k...),
117-
kws...,
117+
sweep_callback=iter ->
118+
default_sweep_callback(iter; exponent_description="time", process_time),
119+
sweep_kwargs...,
118120
)
119-
exponents = [-im*t for t in time_points]
120-
return applyexp(operator, exponents, init_state; sweep_callback, kws...)
121+
exponents = [-im * t for t in time_points]
122+
return applyexp(operator, exponents, init_state; sweep_callback, sweep_kwargs...)
121123
end

src/solvers/eigsolve.jl

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using Accessors: @set
21
using Printf: @printf
32
using ITensors: truncerror
43

@@ -14,42 +13,43 @@ state(E::EigsolveProblem) = E.state
1413
operator(E::EigsolveProblem) = E.operator
1514
max_truncerror(E::EigsolveProblem) = E.max_truncerror
1615

17-
set_operator(E::EigsolveProblem, operator) = (@set E.operator = operator)
18-
set_eigenvalue(E::EigsolveProblem, eigenvalue) = (@set E.eigenvalue = eigenvalue)
19-
set_state(E::EigsolveProblem, state) = (@set E.state = state)
20-
set_max_truncerror(E::EigsolveProblem, truncerror) = (@set E.max_truncerror = truncerror)
21-
22-
function set_truncation_info(E::EigsolveProblem; spectrum=nothing)
16+
function set_truncation_info!(E::EigsolveProblem; spectrum=nothing)
2317
if !isnothing(spectrum)
24-
E = set_max_truncerror(E, max(max_truncerror(E), truncerror(spectrum)))
18+
E.max_truncerror = max(max_truncerror(E), truncerror(spectrum))
2519
end
2620
return E
2721
end
2822

29-
function update(
30-
prob::EigsolveProblem,
31-
local_state,
32-
region_iterator;
33-
outputlevel,
23+
function update!(
24+
region_iter::RegionIterator{<:EigsolveProblem},
25+
local_state;
26+
outputlevel=0,
3427
solver=eigsolve_solver,
35-
kws...,
3628
)
37-
eigval, local_state = solver->optimal_map(operator(prob), ψ), local_state; kws...)
38-
prob = set_eigenvalue(prob, eigval)
29+
prob = problem(region_iter)
30+
31+
eigval, local_state = solver(
32+
ψ -> optimal_map(operator(prob), ψ), local_state; region_kwargs(solver, region_iter)...
33+
)
34+
35+
prob.eigenvalue = eigval
36+
3937
if outputlevel >= 2
40-
@printf(
41-
" Region %s: energy = %.12f\n", current_region(region_iterator), eigenvalue(prob)
42-
)
38+
@printf(" Region %s: energy = %.12f\n", current_region(region_iter), eigenvalue(prob))
4339
end
44-
return prob, local_state
40+
return region_iter, local_state
4541
end
4642

47-
function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, kws...)
43+
function default_sweep_callback(
44+
sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel=0
45+
)
4846
if outputlevel >= 1
49-
if nsweeps >= 10
50-
@printf("After sweep %02d/%d ", sweep, nsweeps)
47+
nsweeps = length(sweep_iterator)
48+
current_sweep = sweep_iterator.which_sweep
49+
if length(sweep_iterator) >= 10
50+
@printf("After sweep %02d/%d ", current_sweep, nsweeps)
5151
else
52-
@printf("After sweep %d/%d ", sweep, nsweeps)
52+
@printf("After sweep %d/%d ", current_sweep, nsweeps)
5353
end
5454
@printf("eigenvalue=%.12f", eigenvalue(problem))
5555
@printf(" maxlinkdim=%d", maxlinkdim(state(problem)))
@@ -60,24 +60,22 @@ function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, k
6060
end
6161

6262
function eigsolve(
63-
operator,
64-
init_state;
65-
nsweeps,
66-
nsites=1,
67-
outputlevel=0,
68-
extract_kwargs=(;),
69-
update_kwargs=(;),
70-
insert_kwargs=(;),
71-
kws...,
63+
operator, init_state; nsweeps, nsites=1, outputlevel=0, factorize_kwargs, sweep_kwargs...
7264
)
7365
init_prob = EigsolveProblem(;
7466
state=align_indices(init_state), operator=ProjTTN(align_indices(operator))
7567
)
76-
sweep_iter = sweep_iterator(
77-
init_prob, nsweeps; nsites, outputlevel, extract_kwargs, update_kwargs, insert_kwargs
68+
sweep_iter = SweepIterator(
69+
init_prob,
70+
nsweeps;
71+
nsites,
72+
outputlevel,
73+
factorize_kwargs,
74+
subspace_expand!_kwargs=(; eigen_kwargs=factorize_kwargs),
75+
sweep_kwargs...,
7876
)
79-
prob = sweep_solve(sweep_iter; outputlevel, kws...)
77+
prob = problem(sweep_solve!(sweep_iter))
8078
return eigenvalue(prob), state(prob)
8179
end
8280

83-
dmrg(args...; kws...) = eigsolve(args...; kws...)
81+
dmrg(operator, init_state; kwargs...) = eigsolve(operator, init_state; kwargs...)

src/solvers/extract.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
function extract(problem, region_iterator; sweep, trunc=(;), kws...)
2-
trunc = truncation_parameters(sweep; trunc...)
3-
region = current_region(region_iterator)
4-
psi = orthogonalize(state(problem), region)
1+
function extract!(region_iter::RegionIterator; subspace_algorithm="nothing")
2+
prob = problem(region_iter)
3+
region = current_region(region_iter)
4+
5+
psi = orthogonalize(state(prob), region)
56
local_state = prod(psi[v] for v in region)
6-
problem = set_state(problem, psi)
7-
problem, local_state = subspace_expand(
8-
problem, local_state, region_iterator; sweep, trunc, kws...
9-
)
10-
shifted_operator = position(operator(problem), state(problem), region)
11-
return set_operator(problem, shifted_operator), local_state
7+
8+
prob.state = psi
9+
10+
_, local_state = subspace_expand!(region_iter, local_state; subspace_algorithm)
11+
shifted_operator = position(operator(prob), state(prob), region)
12+
13+
prob.operator = shifted_operator
14+
15+
return region_iter, local_state
1216
end

0 commit comments

Comments
 (0)