Skip to content
Merged
Show file tree
Hide file tree
Changes from 55 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
7a51966
Add `Problem` as type parameter to `SweepIterator`
Sep 22, 2025
245182b
Format test files and improve comparisons for readabilty on failure
Sep 24, 2025
7af4b25
Redesign iterator interface by introducing AbstractNetworkIterator ab…
Sep 24, 2025
c0ae5d0
Add `EachRegion` adapter that wraps `RegionIterator`, behaving the sa…
Sep 25, 2025
3b9d0af
Add unit tests for the `AbstractNetworkIterator` interface
Sep 30, 2025
4ef4e75
Rename `done` to `laststep` to better reflect the when it evalutes to…
Sep 30, 2025
e112eb4
Rename `previous_region` to `prev_region` to better align with julia …
Sep 30, 2025
da360e0
Rename `PauseAfterIncrement` -> `NoComputeStep` and improve some vari…
Oct 1, 2025
8bfc483
Make `extract` and `subspace_expand` mutating
Oct 3, 2025
1ef8498
Make `update` mutable
Oct 3, 2025
0a6e891
Make `insert` mutable
Oct 3, 2025
0653c47
First implementation of an `options` system.
Oct 3, 2025
d77321e
Simplify options interface to a single function `default_kwargs`.
Oct 6, 2025
aff14c7
Put calls to `extract!` etc in `compute!` function directly
Oct 6, 2025
4b21cc9
Refactor the region plan generating code.
Oct 7, 2025
e71512f
Have `dmrg` take a strict number of arguments
Oct 7, 2025
a4ce308
Purge non-mutating field setter functions.
Oct 7, 2025
a8b2c51
Use `current_kwargs` for getting kwargs from `RegionIterator`
Oct 7, 2025
18a8503
Introduce defaults using `default_kwargs` and be stricter about which…
Oct 7, 2025
0c9022c
Swap order of local_state and region_iter args
Oct 7, 2025
a9be11e
Add some unit tests for the defaults
Oct 7, 2025
4d52088
Rename file options.jl -> test_default_kwargs.jl
Oct 7, 2025
613d533
Fix `euler_sweep` returning kwargs not as `NamedTuple`
Oct 7, 2025
20bf783
The `sweep_solve` callbacks now get called without any keyword argume…
Oct 7, 2025
568c631
Some minor refactoring of the iterators.
Oct 7, 2025
fed9137
The `EachRegion` adapter now flattens the nested Sweep/Region iterato…
Oct 9, 2025
4ce453e
Add tests for `EachRegion` and `eachregion` wrapper functions
Oct 9, 2025
c59a9c5
Rename `laststep` -> `islaststep` in fitting with Julia conventions.
Oct 9, 2025
62195b6
Overhaul `default_kwargs` such that it mirrors the function signature…
Oct 9, 2025
917f2f1
Rename `NoComputeStep` to `IncrementOnly`
Oct 9, 2025
112d55e
Remove @info statement and fix bug with `astypes` not promoting corre…
Oct 10, 2025
0a9f127
Update `default_kwargs` tests.
Oct 10, 2025
e35f325
Remove stray `end` from `adapters.jl`.
Oct 14, 2025
6a8cdb1
Fix typo in docstring of `EachRegion` adapter.
jack-dunham Oct 14, 2025
9760de1
Function `reverse_regions` is now more concise.
Oct 14, 2025
26ece7b
Use explicit imports in `default_kwargs.jl`
Oct 14, 2025
340d805
Fix test imports and broken tests in `test_iterators.jl`.
Oct 14, 2025
f89c379
Merge branch 'network_solvers' of https://github.com/jack-dunham/ITen…
Oct 14, 2025
6a33f29
Rename @default_kwargs -> @define_default_kwargs
Oct 14, 2025
b4bcb93
Remove `astypes` option from `@define_default_kwargs`.
Oct 14, 2025
624f964
Update `default_kwargs` tests.
Oct 14, 2025
bd35f09
Add `sweep_solve` method for `EachRegion` adapter.
Oct 14, 2025
0b5314d
Add `@with_kwargs` macro which automatically splats `default_kwargs` …
Oct 14, 2025
a58ec92
Make use of `@with_kwargs` macro make code more concise.
Oct 14, 2025
b72a08f
The fallback default callback functions now no longer accept `kwargs.…
Oct 15, 2025
c5de5c4
Test fix: tests founds in sub-directories are now actually ran when i…
Oct 15, 2025
2788057
Skip broken tests for now
Oct 15, 2025
33b9e28
Rename `sweep_solve` -> `sweep_solve!` to obey convention
Oct 15, 2025
dedd82e
The `EachRegion` adapter now returns itself from `iterate` instead of…
Oct 15, 2025
d39f09e
The `sweep_solve!` function now always returns the type of the input …
Oct 15, 2025
3f5c97c
Mutating functions now return the first argument before any additiona…
Oct 15, 2025
7ad3138
Remove depreciated `solvers` code and tests from old interface
Oct 16, 2025
60235bc
Method `subspace_expand!(::Backend"densitymatrix")` now defines kwarg…
Oct 16, 2025
da3ad27
Solvers code now no longer relies on `default_kwargs` system
Oct 16, 2025
8725370
Remove `default_kwargs` related to source files
Oct 16, 2025
8afce8a
Delete stale include
mtfishman Oct 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ include("treetensornetworks/projttns/projttn.jl")
include("treetensornetworks/projttns/projttnsum.jl")
include("treetensornetworks/projttns/projouterprodttn.jl")

include("solvers/default_kwargs.jl")
include("solvers/local_solvers/eigsolve.jl")
include("solvers/local_solvers/exponentiate.jl")
include("solvers/local_solvers/runge_kutta.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/abstract_problem.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

abstract type AbstractProblem end

set_truncation_info(P::AbstractProblem, args...; kws...) = P
set_truncation_info!(P::AbstractProblem, args...; kws...) = P
62 changes: 38 additions & 24 deletions src/solvers/adapters.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,46 @@
"""
struct PauseAfterIncrement{S<:AbstractNetworkIterator}

#
# TupleRegionIterator
#
# Adapts outputs to be (region, region_kwargs) tuples
#
# More generic design? maybe just assuming RegionIterator
# or its outputs implement some interface function that
# generates each tuple?
#

mutable struct TupleRegionIterator{RegionIter}
region_iterator::RegionIter
Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
process. This allows one to manually call a custom `compute!` or insert their own code it in
the loop body in place of `compute!`.
"""
struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator
parent::S
end

region_iterator(T::TupleRegionIterator) = T.region_iterator
islaststep(adapter::IncrementOnly) = islaststep(adapter.parent)
state(adapter::IncrementOnly) = state(adapter.parent)
increment!(adapter::IncrementOnly) = increment!(adapter.parent)
compute!(adapter::IncrementOnly) = adapter

function Base.iterate(T::TupleRegionIterator, which=1)
state = iterate(region_iterator(T), which)
isnothing(state) && return nothing
(current_region, region_kwargs) = current_region_plan(region_iterator(T))
return (current_region, region_kwargs), last(state)
end
IncrementOnly(adapter::IncrementOnly) = adapter

"""
region_tuples(R::RegionIterator)
struct EachRegion{SweepIterator} <: AbstractNetworkIterator

The `region_tuples` adapter converts a RegionIterator into an
iterator which outputs a tuple of the form (current_region, current_region_kwargs)
at each step.
Adapter that flattens each region iterator in the parent sweep iterator into a single
iterator.
"""
region_tuples(R::RegionIterator) = TupleRegionIterator(R)
struct EachRegion{SI<:SweepIterator} <: AbstractNetworkIterator
parent::SI
end

# In keeping with Julia convention.
eachregion(iter::SweepIterator) = EachRegion(iter)

# Essential definitions
function islaststep(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
return islaststep(adapter.parent) && islaststep(region_iter)
end
function increment!(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter)
return adapter
end
function compute!(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
compute!(region_iter)
return adapter
end
100 changes: 51 additions & 49 deletions src/solvers/applyexp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Printf: @printf
using Accessors: @set

@kwdef mutable struct ApplyExpProblem{State} <: AbstractProblem
operator
Expand All @@ -11,66 +10,69 @@ operator(A::ApplyExpProblem) = A.operator
state(A::ApplyExpProblem) = A.state
current_exponent(A::ApplyExpProblem) = A.current_exponent
function current_time(A::ApplyExpProblem)
t = im*A.current_exponent
t = im * A.current_exponent
return iszero(imag(t)) ? real(t) : t
end

set_operator(A::ApplyExpProblem, operator) = (@set A.operator = operator)
set_state(A::ApplyExpProblem, state) = (@set A.state = state)
set_current_exponent(A::ApplyExpProblem, exponent) = (@set A.current_exponent = exponent)

function region_plan(A::ApplyExpProblem; nsites, time_step, sweep_kwargs...)
return applyexp_regions(state(A), time_step; nsites, sweep_kwargs...)
# Rename region_plan
function region_plan(A::ApplyExpProblem; nsites, exponent_step, sweep_kwargs...)
# The `exponent_step` kwarg for the `update!` function needs some pre-processing.
return applyexp_regions(state(A), exponent_step; nsites, sweep_kwargs...)
end

function update(
prob::ApplyExpProblem,
local_state,
region_iterator;
function update!(
region_iter::RegionIterator{<:ApplyExpProblem},
local_state;
nsites,
exponent_step,
solver=runge_kutta_solver,
outputlevel,
kws...,
)
iszero(abs(exponent_step)) && return prob, local_state
prob = problem(region_iter)

if iszero(abs(exponent_step))
return region_iter, local_state
end

local_state, info = solver(
x->optimal_map(operator(prob), x), exponent_step, local_state; kws...
solver_kwargs = region_kwargs(solver, region_iter)

local_state, _ = solver(
x -> optimal_map(operator(prob), x), exponent_step, local_state; solver_kwargs...
)
if nsites==1
curr_reg = current_region(region_iterator)
next_reg = next_region(region_iterator)
if nsites == 1
curr_reg = current_region(region_iter)
next_reg = next_region(region_iter)
if !isnothing(next_reg) && next_reg != curr_reg
next_edge = first(edge_sequence_between_regions(state(prob), curr_reg, next_reg))
v1, v2 = src(next_edge), dst(next_edge)
psi = copy(state(prob))
psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2]))
shifted_operator = position(operator(prob), psi, NamedEdge(v1=>v2))
R_t, _ = solver(x->optimal_map(shifted_operator, x), -exponent_step, R; kws...)
local_state = psi[v1]*R_t
shifted_operator = position(operator(prob), psi, NamedEdge(v1 => v2))
R_t, _ = solver(
x -> optimal_map(shifted_operator, x), -exponent_step, R; solver_kwargs...
)
local_state = psi[v1] * R_t
end
end

prob = set_current_exponent(prob, current_exponent(prob)+exponent_step)
prob.current_exponent += exponent_step

return prob, local_state
return region_iter, local_state
end

function sweep_callback(
problem::ApplyExpProblem;
function default_sweep_callback(
sweep_iterator::SweepIterator{<:ApplyExpProblem};
exponent_description="exponent",
outputlevel,
sweep,
nsweeps,
outputlevel=0,
process_time=identity,
kws...,
)
if outputlevel >= 1
the_problem = problem(sweep_iterator)
@printf(
" Current %s = %s, ", exponent_description, process_time(current_exponent(problem))
" Current %s = %s, ",
exponent_description,
process_time(current_exponent(the_problem))
)
@printf("maxlinkdim=%d", maxlinkdim(state(problem)))
@printf("maxlinkdim=%d", maxlinkdim(state(the_problem)))
println()
flush(stdout)
end
Expand All @@ -79,19 +81,20 @@ end
function applyexp(
init_prob::AbstractProblem,
exponents;
extract_kwargs=(;),
update_kwargs=(;),
insert_kwargs=(;),
outputlevel=0,
nsites=1,
sweep_callback=default_sweep_callback,
order=4,
kws...,
nsites=2,
sweep_kwargs...,
)
exponent_steps = diff([zero(eltype(exponents)); exponents])
sweep_kws = (; outputlevel, extract_kwargs, insert_kwargs, nsites, order, update_kwargs)
kws_array = [(; sweep_kws..., time_step=t) for t in exponent_steps]
sweep_iter = sweep_iterator(init_prob, kws_array)
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)

kws_array = [
(; order, nsites, sweep_kwargs..., exponent_step) for exponent_step in exponent_steps
]
sweep_iter = SweepIterator(init_prob, kws_array)

converged_prob = problem(sweep_solve!(sweep_callback, sweep_iter))

return state(converged_prob)
end

Expand All @@ -111,11 +114,10 @@ function time_evolve(
time_points,
init_state;
process_time=process_real_times,
sweep_callback=(
a...; k...
)->sweep_callback(a...; exponent_description="time", process_time, k...),
kws...,
sweep_callback=iter ->
default_sweep_callback(iter; exponent_description="time", process_time),
sweep_kwargs...,
)
exponents = [-im*t for t in time_points]
return applyexp(operator, exponents, init_state; sweep_callback, kws...)
exponents = [-im * t for t in time_points]
return applyexp(operator, exponents, init_state; sweep_callback, sweep_kwargs...)
end
72 changes: 35 additions & 37 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Accessors: @set
using Printf: @printf
using ITensors: truncerror

Expand All @@ -14,42 +13,43 @@ state(E::EigsolveProblem) = E.state
operator(E::EigsolveProblem) = E.operator
max_truncerror(E::EigsolveProblem) = E.max_truncerror

set_operator(E::EigsolveProblem, operator) = (@set E.operator = operator)
set_eigenvalue(E::EigsolveProblem, eigenvalue) = (@set E.eigenvalue = eigenvalue)
set_state(E::EigsolveProblem, state) = (@set E.state = state)
set_max_truncerror(E::EigsolveProblem, truncerror) = (@set E.max_truncerror = truncerror)

function set_truncation_info(E::EigsolveProblem; spectrum=nothing)
function set_truncation_info!(E::EigsolveProblem; spectrum=nothing)
if !isnothing(spectrum)
E = set_max_truncerror(E, max(max_truncerror(E), truncerror(spectrum)))
E.max_truncerror = max(max_truncerror(E), truncerror(spectrum))
end
return E
end

function update(
prob::EigsolveProblem,
local_state,
region_iterator;
outputlevel,
function update!(
region_iter::RegionIterator{<:EigsolveProblem},
local_state;
outputlevel=0,
solver=eigsolve_solver,
kws...,
)
eigval, local_state = solver(ψ->optimal_map(operator(prob), ψ), local_state; kws...)
prob = set_eigenvalue(prob, eigval)
prob = problem(region_iter)

eigval, local_state = solver(
ψ -> optimal_map(operator(prob), ψ), local_state; region_kwargs(solver, region_iter)...
)

prob.eigenvalue = eigval

if outputlevel >= 2
@printf(
" Region %s: energy = %.12f\n", current_region(region_iterator), eigenvalue(prob)
)
@printf(" Region %s: energy = %.12f\n", current_region(region_iter), eigenvalue(prob))
end
return prob, local_state
return region_iter, local_state
end

function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, kws...)
function default_sweep_callback(
sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel=0
)
if outputlevel >= 1
if nsweeps >= 10
@printf("After sweep %02d/%d ", sweep, nsweeps)
nsweeps = length(sweep_iterator)
current_sweep = sweep_iterator.which_sweep
if length(sweep_iterator) >= 10
@printf("After sweep %02d/%d ", current_sweep, nsweeps)
else
@printf("After sweep %d/%d ", sweep, nsweeps)
@printf("After sweep %d/%d ", current_sweep, nsweeps)
end
@printf("eigenvalue=%.12f", eigenvalue(problem))
@printf(" maxlinkdim=%d", maxlinkdim(state(problem)))
Expand All @@ -60,24 +60,22 @@ function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, k
end

function eigsolve(
operator,
init_state;
nsweeps,
nsites=1,
outputlevel=0,
extract_kwargs=(;),
update_kwargs=(;),
insert_kwargs=(;),
kws...,
operator, init_state; nsweeps, nsites=1, outputlevel=0, factorize_kwargs, sweep_kwargs...
)
init_prob = EigsolveProblem(;
state=align_indices(init_state), operator=ProjTTN(align_indices(operator))
)
sweep_iter = sweep_iterator(
init_prob, nsweeps; nsites, outputlevel, extract_kwargs, update_kwargs, insert_kwargs
sweep_iter = SweepIterator(
init_prob,
nsweeps;
nsites,
outputlevel,
factorize_kwargs,
subspace_expand!_kwargs=(; eigen_kwargs=factorize_kwargs),
sweep_kwargs...,
)
prob = sweep_solve(sweep_iter; outputlevel, kws...)
prob = problem(sweep_solve!(sweep_iter))
return eigenvalue(prob), state(prob)
end

dmrg(args...; kws...) = eigsolve(args...; kws...)
dmrg(operator, init_state; kwargs...) = eigsolve(operator, init_state; kwargs...)
24 changes: 14 additions & 10 deletions src/solvers/extract.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
function extract(problem, region_iterator; sweep, trunc=(;), kws...)
trunc = truncation_parameters(sweep; trunc...)
region = current_region(region_iterator)
psi = orthogonalize(state(problem), region)
function extract!(region_iter::RegionIterator; subspace_algorithm="nothing")
prob = problem(region_iter)
region = current_region(region_iter)

psi = orthogonalize(state(prob), region)
local_state = prod(psi[v] for v in region)
problem = set_state(problem, psi)
problem, local_state = subspace_expand(
problem, local_state, region_iterator; sweep, trunc, kws...
)
shifted_operator = position(operator(problem), state(problem), region)
return set_operator(problem, shifted_operator), local_state

prob.state = psi

_, local_state = subspace_expand!(region_iter, local_state; subspace_algorithm)
shifted_operator = position(operator(prob), state(prob), region)

prob.operator = shifted_operator

return region_iter, local_state
end
Loading
Loading