Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorNetworksNext"
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.11"
version = "0.1.12"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
3 changes: 3 additions & 0 deletions src/ITensorNetworksNext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,8 @@ include("lazynameddimsarrays.jl")
include("abstracttensornetwork.jl")
include("tensornetwork.jl")
include("contract_network.jl")
include("abstract_problem.jl")
include("iterators.jl")
include("adapters.jl")

end
1 change: 1 addition & 0 deletions src/abstract_problem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
abstract type AbstractProblem end
45 changes: 45 additions & 0 deletions src/adapters.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator

Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
process. This allows one to manually call a custom `compute!` or insert their own code it in
the loop body in place of `compute!`.
"""
struct IncrementOnly{S <: AbstractNetworkIterator} <: AbstractNetworkIterator
parent::S
end

islaststep(adapter::IncrementOnly) = islaststep(adapter.parent)
increment!(adapter::IncrementOnly) = increment!(adapter.parent)
compute!(adapter::IncrementOnly) = adapter

IncrementOnly(adapter::IncrementOnly) = adapter

"""
struct EachRegion{SweepIterator} <: AbstractNetworkIterator

Adapter that flattens each region iterator in the parent sweep iterator into a single
iterator.
"""
struct EachRegion{SI <: SweepIterator} <: AbstractNetworkIterator
parent::SI
end

# In keeping with Julia convention.
eachregion(iter::SweepIterator) = EachRegion(iter)

# Essential definitions
function islaststep(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
return islaststep(adapter.parent) && islaststep(region_iter)
end
function increment!(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter)
return adapter
end
function compute!(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
compute!(region_iter)
return adapter
end
170 changes: 170 additions & 0 deletions src/iterators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""
abstract type AbstractNetworkIterator

A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins
with a call to `increment!` before executing `compute!`, however the initial call to
`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that
this call is implict. Termination of the iterator is controlled by the function `done`.
"""
abstract type AbstractNetworkIterator end

# We use greater than or equals here as we increment the state at the start of the iteration
islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator)

function Base.iterate(iterator::AbstractNetworkIterator, init = true)
# The assumption is that first "increment!" is implicit, therefore we must skip the
# the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not
# defined when length < 1,
init || islaststep(iterator) && return nothing
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
# define a method for increment! This way we avoid cases where one may wish to nest
# calls to different step! methods accidentaly incrementing multiple times.
init || increment!(iterator)
rv = compute!(iterator)
return rv, false
end

increment!(iterator::AbstractNetworkIterator) = throw(MethodError(increment!, Tuple{typeof(iterator)}))
compute!(iterator::AbstractNetworkIterator) = iterator

step!(iterator::AbstractNetworkIterator) = step!(identity, iterator)
function step!(f, iterator::AbstractNetworkIterator)
compute!(iterator)
f(iterator)
increment!(iterator)
return iterator
end

#
# RegionIterator
#
"""
struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
"""
mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
problem::Problem
region_plan::RegionPlan
which_region::Int
const which_sweep::Int
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R}
if isempty(region_plan)
throw(ArgumentError("Cannot construct a region iterator with 0 elements."))
end
return new{P, R}(problem, region_plan, 1, sweep)
end
end

function RegionIterator(problem; sweep, sweep_kwargs...)
plan = region_plan(problem; sweep_kwargs...)
return RegionIterator(problem, plan, sweep)
end

state(region_iter::RegionIterator) = region_iter.which_region
Base.length(region_iter::RegionIterator) = length(region_iter.region_plan)

problem(region_iter::RegionIterator) = region_iter.problem

function current_region_plan(region_iter::RegionIterator)
return region_iter.region_plan[region_iter.which_region]
end

function current_region(region_iter::RegionIterator)
region, _ = current_region_plan(region_iter)
return region
end

function region_kwargs(region_iter::RegionIterator)
_, kwargs = current_region_plan(region_iter)
return kwargs
end
function region_kwargs(f::Function, iter::RegionIterator)
return get(region_kwargs(iter), Symbol(f, :_kwargs), (;))
end

function prev_region(region_iter::RegionIterator)
state(region_iter) <= 1 && return nothing
prev, _ = region_iter.region_plan[region_iter.which_region - 1]
return prev
end

function next_region(region_iter::RegionIterator)
islaststep(region_iter) && return nothing
next, _ = region_iter.region_plan[region_iter.which_region + 1]
return next
end

#
# Functions associated with RegionIterator
#
function increment!(region_iter::RegionIterator)
region_iter.which_region += 1
return region_iter
end

function compute!(iter::RegionIterator)
extract!(iter; region_kwargs(extract!, iter)...)
update!(iter; region_kwargs(update!, iter)...)
insert!(iter; region_kwargs(insert!, iter)...)

return iter
end

region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...)

#
# SweepIterator
#

mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator
region_iter::RegionIterator{Problem}
sweep_kwargs::Iterators.Stateful{Iter}
which_sweep::Int
function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter}
stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs)

first_state = Iterators.peel(stateful_sweep_kwargs)

if isnothing(first_state)
throw(ArgumentError("Cannot construct a sweep iterator with 0 elements."))
end

first_kwargs, _ = first_state
region_iter = RegionIterator(problem; sweep = 1, first_kwargs...)

return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1)
end
end

islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs))

region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter
problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter))

state(sweep_iter::SweepIterator) = sweep_iter.which_sweep
Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs)
function increment!(sweep_iter::SweepIterator)
sweep_iter.which_sweep += 1
sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs)
update_region_iterator!(sweep_iter; sweep_kwargs...)
return sweep_iter
end

function update_region_iterator!(iterator::SweepIterator; kwargs...)
sweep = state(iterator)
iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...)
return iterator
end

function compute!(sweep_iter::SweepIterator)
for _ in sweep_iter.region_iter
# TODO: Is it sensible to execute the default region callback function?
end
return
end

# More basic constructor where sweep_kwargs are constant throughout sweeps
function SweepIterator(problem, nsweeps::Int; sweep_kwargs...)
# Initialize this to an empty RegionIterator
sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps)
return SweepIterator(problem, sweep_kwargs_iter)
end
Loading
Loading