Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
7a51966
Add `Problem` as type parameter to `SweepIterator`
Sep 22, 2025
245182b
Format test files and improve comparisons for readabilty on failure
Sep 24, 2025
7af4b25
Redesign iterator interface by introducing AbstractNetworkIterator ab…
Sep 24, 2025
c0ae5d0
Add `EachRegion` adapter that wraps `RegionIterator`, behaving the sa…
Sep 25, 2025
3b9d0af
Add unit tests for the `AbstractNetworkIterator` interface
Sep 30, 2025
4ef4e75
Rename `done` to `laststep` to better reflect the when it evalutes to…
Sep 30, 2025
e112eb4
Rename `previous_region` to `prev_region` to better align with julia …
Sep 30, 2025
da360e0
Rename `PauseAfterIncrement` -> `NoComputeStep` and improve some vari…
Oct 1, 2025
8bfc483
Make `extract` and `subspace_expand` mutating
Oct 3, 2025
1ef8498
Make `update` mutable
Oct 3, 2025
0a6e891
Make `insert` mutable
Oct 3, 2025
0653c47
First implementation of an `options` system.
Oct 3, 2025
d77321e
Simplify options interface to a single function `default_kwargs`.
Oct 6, 2025
aff14c7
Put calls to `extract!` etc in `compute!` function directly
Oct 6, 2025
4b21cc9
Refactor the region plan generating code.
Oct 7, 2025
e71512f
Have `dmrg` take a strict number of arguments
Oct 7, 2025
a4ce308
Purge non-mutating field setter functions.
Oct 7, 2025
a8b2c51
Use `current_kwargs` for getting kwargs from `RegionIterator`
Oct 7, 2025
18a8503
Introduce defaults using `default_kwargs` and be stricter about which…
Oct 7, 2025
0c9022c
Swap order of local_state and region_iter args
Oct 7, 2025
a9be11e
Add some unit tests for the defaults
Oct 7, 2025
4d52088
Rename file options.jl -> test_default_kwargs.jl
Oct 7, 2025
613d533
Fix `euler_sweep` returning kwargs not as `NamedTuple`
Oct 7, 2025
20bf783
The `sweep_solve` callbacks now get called without any keyword argume…
Oct 7, 2025
568c631
Some minor refactoring of the iterators.
Oct 7, 2025
fed9137
The `EachRegion` adapter now flattens the nested Sweep/Region iterato…
Oct 9, 2025
4ce453e
Add tests for `EachRegion` and `eachregion` wrapper functions
Oct 9, 2025
c59a9c5
Rename `laststep` -> `islaststep` in fitting with Julia conventions.
Oct 9, 2025
62195b6
Overhaul `default_kwargs` such that it mirrors the function signature…
Oct 9, 2025
917f2f1
Rename `NoComputeStep` to `IncrementOnly`
Oct 9, 2025
112d55e
Remove @info statement and fix bug with `astypes` not promoting corre…
Oct 10, 2025
0a9f127
Update `default_kwargs` tests.
Oct 10, 2025
e35f325
Remove stray `end` from `adapters.jl`.
Oct 14, 2025
6a8cdb1
Fix typo in docstring of `EachRegion` adapter.
jack-dunham Oct 14, 2025
9760de1
Function `reverse_regions` is now more concise.
Oct 14, 2025
26ece7b
Use explicit imports in `default_kwargs.jl`
Oct 14, 2025
340d805
Fix test imports and broken tests in `test_iterators.jl`.
Oct 14, 2025
f89c379
Merge branch 'network_solvers' of https://github.com/jack-dunham/ITen…
Oct 14, 2025
6a33f29
Rename @default_kwargs -> @define_default_kwargs
Oct 14, 2025
b4bcb93
Remove `astypes` option from `@define_default_kwargs`.
Oct 14, 2025
624f964
Update `default_kwargs` tests.
Oct 14, 2025
bd35f09
Add `sweep_solve` method for `EachRegion` adapter.
Oct 14, 2025
0b5314d
Add `@with_kwargs` macro which automatically splats `default_kwargs` …
Oct 14, 2025
a58ec92
Make use of `@with_kwargs` macro make code more concise.
Oct 14, 2025
b72a08f
The fallback default callback functions now no longer accept `kwargs.…
Oct 15, 2025
c5de5c4
Test fix: tests founds in sub-directories are now actually ran when i…
Oct 15, 2025
2788057
Skip broken tests for now
Oct 15, 2025
33b9e28
Rename `sweep_solve` -> `sweep_solve!` to obey convention
Oct 15, 2025
dedd82e
The `EachRegion` adapter now returns itself from `iterate` instead of…
Oct 15, 2025
d39f09e
The `sweep_solve!` function now always returns the type of the input …
Oct 15, 2025
3f5c97c
Mutating functions now return the first argument before any additiona…
Oct 15, 2025
7ad3138
Remove depreciated `solvers` code and tests from old interface
Oct 16, 2025
60235bc
Method `subspace_expand!(::Backend"densitymatrix")` now defines kwarg…
Oct 16, 2025
da3ad27
Solvers code now no longer relies on `default_kwargs` system
Oct 16, 2025
8725370
Remove `default_kwargs` related to source files
Oct 16, 2025
8afce8a
Delete stale include
mtfishman Oct 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/solvers/adapters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,21 @@ iterator which outputs a tuple of the form (current_region, current_region_kwarg
at each step.
"""
region_tuples(R::RegionIterator) = TupleRegionIterator(R)

"""
struct PauseAfterIncrement{S<:AbstractNetworkIterator}

Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
process. This allows one to manually call a custom `compute!` or insert their own code it in
the loop body in place of `compute!`.
"""
struct PauseAfterIncrement{S<:AbstractNetworkIterator} <: AbstractNetworkIterator
parent::S
end

done(NC::PauseAfterIncrement) = done(NC.parent)
state(NC::PauseAfterIncrement) = state(NC.parent)
increment!(NC::PauseAfterIncrement) = increment!(NC.parent)
compute!(NC::PauseAfterIncrement) = NC

PauseAfterIncrement(NC::PauseAfterIncrement) = NC
39 changes: 20 additions & 19 deletions src/solvers/applyexp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ operator(A::ApplyExpProblem) = A.operator
state(A::ApplyExpProblem) = A.state
current_exponent(A::ApplyExpProblem) = A.current_exponent
function current_time(A::ApplyExpProblem)
t = im*A.current_exponent
t = im * A.current_exponent
return iszero(imag(t)) ? real(t) : t
end

Expand All @@ -36,41 +36,42 @@ function update(
iszero(abs(exponent_step)) && return prob, local_state

local_state, info = solver(
x->optimal_map(operator(prob), x), exponent_step, local_state; kws...
x -> optimal_map(operator(prob), x), exponent_step, local_state; kws...
)
if nsites==1
if nsites == 1
curr_reg = current_region(region_iterator)
next_reg = next_region(region_iterator)
if !isnothing(next_reg) && next_reg != curr_reg
next_edge = first(edge_sequence_between_regions(state(prob), curr_reg, next_reg))
v1, v2 = src(next_edge), dst(next_edge)
psi = copy(state(prob))
psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2]))
shifted_operator = position(operator(prob), psi, NamedEdge(v1=>v2))
R_t, _ = solver(x->optimal_map(shifted_operator, x), -exponent_step, R; kws...)
local_state = psi[v1]*R_t
shifted_operator = position(operator(prob), psi, NamedEdge(v1 => v2))
R_t, _ = solver(x -> optimal_map(shifted_operator, x), -exponent_step, R; kws...)
local_state = psi[v1] * R_t
end
end

prob = set_current_exponent(prob, current_exponent(prob)+exponent_step)
prob = set_current_exponent(prob, current_exponent(prob) + exponent_step)

return prob, local_state
end

function sweep_callback(
problem::ApplyExpProblem;
function default_sweep_callback(
sweep_iterator::SweepIterator{<:ApplyExpProblem};
exponent_description="exponent",
outputlevel,
sweep,
nsweeps,
process_time=identity,
kws...,
kwargs...,
)
if outputlevel >= 1
the_problem = problem(sweep_iterator)
@printf(
" Current %s = %s, ", exponent_description, process_time(current_exponent(problem))
" Current %s = %s, ",
exponent_description,
process_time(current_exponent(the_problem))
)
@printf("maxlinkdim=%d", maxlinkdim(state(problem)))
@printf("maxlinkdim=%d", maxlinkdim(state(the_problem)))
println()
flush(stdout)
end
Expand All @@ -88,9 +89,10 @@ function applyexp(
kws...,
)
exponent_steps = diff([zero(eltype(exponents)); exponents])
# exponent_steps = diff(exponents)
sweep_kws = (; outputlevel, extract_kwargs, insert_kwargs, nsites, order, update_kwargs)
kws_array = [(; sweep_kws..., time_step=t) for t in exponent_steps]
sweep_iter = sweep_iterator(init_prob, kws_array)
sweep_iter = SweepIterator(init_prob, kws_array)
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)
return state(converged_prob)
end
Expand All @@ -111,11 +113,10 @@ function time_evolve(
time_points,
init_state;
process_time=process_real_times,
sweep_callback=(
a...; k...
)->sweep_callback(a...; exponent_description="time", process_time, k...),
sweep_callback=(a...; k...) ->
default_sweep_callback(a...; exponent_description="time", process_time, k...),
kws...,
)
exponents = [-im*t for t in time_points]
exponents = [-im * t for t in time_points]
return applyexp(operator, exponents, init_state; sweep_callback, kws...)
end
16 changes: 10 additions & 6 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function update(
solver=eigsolve_solver,
kws...,
)
eigval, local_state = solver(ψ->optimal_map(operator(prob), ψ), local_state; kws...)
eigval, local_state = solver(ψ -> optimal_map(operator(prob), ψ), local_state; kws...)
prob = set_eigenvalue(prob, eigval)
if outputlevel >= 2
@printf(
Expand All @@ -44,12 +44,16 @@ function update(
return prob, local_state
end

function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, kws...)
function default_sweep_callback(
sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel
)
if outputlevel >= 1
if nsweeps >= 10
@printf("After sweep %02d/%d ", sweep, nsweeps)
nsweeps = length(sweep_iterator)
current_sweep = sweep_iterator.which_sweep
if length(sweep_iterator) >= 10
@printf("After sweep %02d/%d ", current_sweep, nsweeps)
else
@printf("After sweep %d/%d ", sweep, nsweeps)
@printf("After sweep %d/%d ", current_sweep, nsweeps)
end
@printf("eigenvalue=%.12f", eigenvalue(problem))
@printf(" maxlinkdim=%d", maxlinkdim(state(problem)))
Expand All @@ -73,7 +77,7 @@ function eigsolve(
init_prob = EigsolveProblem(;
state=align_indices(init_state), operator=ProjTTN(align_indices(operator))
)
sweep_iter = sweep_iterator(
sweep_iter = SweepIterator(
init_prob, nsweeps; nsites, outputlevel, extract_kwargs, update_kwargs, insert_kwargs
)
prob = sweep_solve(sweep_iter; outputlevel, kws...)
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/fitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function fit_tensornetwork(
insert_kwargs = (; insert_kwargs..., normalize, set_orthogonal_region=false)
common_sweep_kwargs = (; nsites, outputlevel, update_kwargs, insert_kwargs)
kwargs_array = [(; common_sweep_kwargs..., sweep=s) for s in 1:nsweeps]
sweep_iter = sweep_iterator(init_prob, kwargs_array)
sweep_iter = SweepIterator(init_prob, kwargs_array)
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)
return rename_vertices(inv_vertex_map(overlap_network), ket(converged_prob))
end
Expand Down
188 changes: 123 additions & 65 deletions src/solvers/iterators.jl
Original file line number Diff line number Diff line change
@@ -1,100 +1,158 @@
#
# SweepIterator
#

mutable struct SweepIterator
sweep_kws
region_iter
which_sweep::Int
end

problem(S::SweepIterator) = problem(S.region_iter)
"""
abstract type AbstractNetworkIterator
Base.length(S::SweepIterator) = length(S.sweep_kws)
A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins
with a call to `increment!` before executing `compute!`, however the initial call to
`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that
this call is implict. Termination of the iterator is controlled by the function `done`.
"""
abstract type AbstractNetworkIterator end

function Base.iterate(S::SweepIterator, which=nothing)
if isnothing(which)
sweep_kws_state = iterate(S.sweep_kws)
else
sweep_kws_state = iterate(S.sweep_kws, which)
end
isnothing(sweep_kws_state) && return nothing
current_sweep_kws, next = sweep_kws_state
# We use greater than or equals here as we increment the state at the start of the iteration
done(NI::AbstractNetworkIterator) = state(NI) >= length(NI)

if !isnothing(which)
S.region_iter = region_iterator(
problem(S.region_iter); sweep=S.which_sweep, current_sweep_kws...
)
end
S.which_sweep += 1
return S.region_iter, next
function Base.iterate(NI::AbstractNetworkIterator, init=true)
done(NI) && return nothing
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
# define a method for increment! This way we avoid cases where one may wish to nest
# calls to different step! methods accidentaly incrementing multiple times.
init || increment!(NI)
rv = compute!(NI)
return rv, false
end

function sweep_iterator(problem, sweep_kws)
region_iter = region_iterator(problem; sweep=1, first(sweep_kws)...)
return SweepIterator(sweep_kws, region_iter, 1)
end
function increment! end
compute!(NI::AbstractNetworkIterator) = NI

function sweep_iterator(problem, nsweeps::Integer; sweep_kws...)
return sweep_iterator(problem, Iterators.repeated(sweep_kws, nsweeps))
step!(NI::AbstractNetworkIterator) = step!(identity, NI)
function step!(f, NI::AbstractNetworkIterator)
compute!(NI)
f(NI)
increment!(NI)
return NI
end

#
# RegionIterator
#

@kwdef mutable struct RegionIterator{Problem,RegionPlan}
"""
struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
"""
mutable struct RegionIterator{Problem,RegionPlan} <: AbstractNetworkIterator
problem::Problem
region_plan::RegionPlan
which_region::Int = 1
const sweep::Int
which_region::Int
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P,R}
return new{P,R}(problem, region_plan, sweep, 1)
end
end

state(R::RegionIterator) = R.which_region
Base.length(R::RegionIterator) = length(R.region_plan)

problem(R::RegionIterator) = R.problem

current_region_plan(R::RegionIterator) = R.region_plan[R.which_region]
current_region(R::RegionIterator) = current_region_plan(R)[1]
region_kwargs(R::RegionIterator) = current_region_plan(R)[2]
function previous_region(R::RegionIterator)
R.which_region==1 ? nothing : R.region_plan[R.which_region - 1][1]

function current_region(R::RegionIterator)
region, _ = current_region_plan(R)
return region
end
function next_region(R::RegionIterator)
R.which_region==length(R.region_plan) ? nothing : R.region_plan[R.which_region + 1][1]

function current_region_kwargs(R::RegionIterator)
_, kwargs = current_region_plan(R)
return kwargs
end

function previous_region(R::RegionIterator)
state(R) <= 1 && return nothing
prev, _ = R.region_plan[R.which_region - 1]
return prev
end
is_last_region(R::RegionIterator) = isnothing(next_region(R))

function Base.iterate(R::RegionIterator, which=1)
R.which_region = which
region_plan_state = iterate(R.region_plan, which)
isnothing(region_plan_state) && return nothing
(current_region, region_kwargs), next = region_plan_state
R.problem = region_step(problem(R), R; region_kwargs...)
return R, next
function next_region(R::RegionIterator)
is_last_region(R) && return nothing
next, _ = R.region_plan[R.which_region + 1]
return next
end
is_last_region(R::RegionIterator) = length(R) === state(R)

#
# Functions associated with RegionIterator
#

function region_iterator(problem; sweep_kwargs...)
return RegionIterator(; problem, region_plan=region_plan(problem; sweep_kwargs...))
function compute!(R::RegionIterator)
region_kwargs = current_region_kwargs(R)
R.problem = region_step(R; region_kwargs...)
return R
end
function increment!(R::RegionIterator)
R.which_region += 1
return R
end

function RegionIterator(problem; sweep, sweep_kwargs...)
plan = region_plan(problem; sweep, sweep_kwargs...)
return RegionIterator(problem, plan, sweep)
end

function region_step(
problem,
region_iterator;
extract_kwargs=(;),
update_kwargs=(;),
insert_kwargs=(;),
sweep,
kws...,
region_iterator; extract_kwargs=(;), update_kwargs=(;), insert_kwargs=(;), kws...
)
problem, local_state = extract(problem, region_iterator; extract_kwargs..., sweep, kws...)
problem, local_state = update(
problem, local_state, region_iterator; update_kwargs..., kws...
)
problem = insert(problem, local_state, region_iterator; sweep, insert_kwargs..., kws...)
return problem
prob = problem(region_iterator)

sweep = region_iterator.sweep

prob, local_state = extract(prob, region_iterator; extract_kwargs..., sweep, kws...)
prob, local_state = update(prob, local_state, region_iterator; update_kwargs..., kws...)
prob = insert(prob, local_state, region_iterator; sweep, insert_kwargs..., kws...)
return prob
end

function region_plan(problem; kws...)
return euler_sweep(state(problem); kws...)
end

#
# SweepIterator
#

mutable struct SweepIterator{Problem} <: AbstractNetworkIterator
sweep_kws
region_iter::RegionIterator{Problem}
which_sweep::Int
function SweepIterator(problem, sweep_kws)
sweep_kws = Iterators.Stateful(sweep_kws)
first_kwargs, _ = Iterators.peel(sweep_kws)
region_iter = RegionIterator(problem; sweep=1, first_kwargs...)
return new{typeof(problem)}(sweep_kws, region_iter, 1)
end
end

done(SR::SweepIterator) = isnothing(peek(SR.sweep_kws))

region_iterator(S::SweepIterator) = S.region_iter
problem(S::SweepIterator) = problem(region_iterator(S))

state(SR::SweepIterator) = SR.which_sweep
Base.length(S::SweepIterator) = length(S.sweep_kws)
function increment!(SR::SweepIterator)
SR.which_sweep += 1
sweep_kwargs, _ = Iterators.peel(SR.sweep_kws)
SR.region_iter = RegionIterator(problem(SR); sweep=state(SR), sweep_kwargs...)
return SR
end

function compute!(SR::SweepIterator)
for _ in SR.region_iter
# TODO: Is it sensible to execute the default region callback function?
end
end

# More basic constructor where sweep_kwargs are constant throughout sweeps
function SweepIterator(problem, nsweeps::Int; sweep_kwargs...)
# Initialize this to an empty RegionIterator
sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps)
return SweepIterator(problem, sweep_kwargs_iter)
end
Loading
Loading