Skip to content
2 changes: 2 additions & 0 deletions src/ITensorNetworksNext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ include("lazynameddimsarrays.jl")
include("abstracttensornetwork.jl")
include("tensornetwork.jl")
include("contract_network.jl")
include("abstract_problem.jl")
include("iterators.jl")

end
1 change: 1 addition & 0 deletions src/abstract_problem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
abstract type AbstractProblem end
173 changes: 173 additions & 0 deletions src/iterators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""
abstract type AbstractNetworkIterator

A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins
with a call to `increment!` before executing `compute!`, however the initial call to
`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that
this call is implict. Termination of the iterator is controlled by the function `done`.
"""
abstract type AbstractNetworkIterator end

# We use greater than or equals here as we increment the state at the start of the iteration
islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator)

function Base.iterate(iterator::AbstractNetworkIterator, init = true)
# The assumption is that first "increment!" is implicit, therefore we must skip the
# the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not
# defined when length < 1,
init || islaststep(iterator) && return nothing
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
# define a method for increment! This way we avoid cases where one may wish to nest
# calls to different step! methods accidentaly incrementing multiple times.
init || increment!(iterator)
rv = compute!(iterator)
return rv, false
end

function increment! end
compute!(iterator::AbstractNetworkIterator) = iterator

step!(iterator::AbstractNetworkIterator) = step!(identity, iterator)
function step!(f, iterator::AbstractNetworkIterator)
compute!(iterator)
f(iterator)
increment!(iterator)
return iterator
end

#
# RegionIterator
#
"""
struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
"""
mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
problem::Problem
region_plan::RegionPlan
which_region::Int
const which_sweep::Int
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R}
if length(region_plan) == 0
throw(BoundsError("Cannot construct a region iterator with 0 elements."))
end
return new{P, R}(problem, region_plan, 1, sweep)
end
end

function RegionIterator(problem; sweep, sweep_kwargs...)
plan = region_plan(problem; sweep_kwargs...)
return RegionIterator(problem, plan, sweep)
end

function new_region_iterator(iterator::RegionIterator; sweep_kwargs...)
return RegionIterator(iterator.problem; sweep_kwargs...)
end

state(region_iter::RegionIterator) = region_iter.which_region
Base.length(region_iter::RegionIterator) = length(region_iter.region_plan)

problem(region_iter::RegionIterator) = region_iter.problem

function current_region_plan(region_iter::RegionIterator)
return region_iter.region_plan[region_iter.which_region]
end

function current_region(region_iter::RegionIterator)
region, _ = current_region_plan(region_iter)
return region
end

function region_kwargs(region_iter::RegionIterator)
_, kwargs = current_region_plan(region_iter)
return kwargs
end
function region_kwargs(f::Function, iter::RegionIterator)
return get(region_kwargs(iter), Symbol(f, :_kwargs), (;))
end

function prev_region(region_iter::RegionIterator)
state(region_iter) <= 1 && return nothing
prev, _ = region_iter.region_plan[region_iter.which_region - 1]
return prev
end

function next_region(region_iter::RegionIterator)
islaststep(region_iter) && return nothing
next, _ = region_iter.region_plan[region_iter.which_region + 1]
return next
end

#
# Functions associated with RegionIterator
#
function increment!(region_iter::RegionIterator)
region_iter.which_region += 1
return region_iter
end

function compute!(iter::RegionIterator)
_, local_state = extract!(iter; region_kwargs(extract!, iter)...)
_, local_state = update!(iter, local_state; region_kwargs(update!, iter)...)
insert!(iter, local_state; region_kwargs(insert!, iter)...)

return iter
end

region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...)

#
# SweepIterator
#

mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator
region_iter::RegionIterator{Problem}
sweep_kwargs::Iterators.Stateful{Iter}
which_sweep::Int
function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter}
stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs)
first_state = Iterators.peel(stateful_sweep_kwargs)

if isnothing(first_state)
throw(BoundsError("Cannot construct a sweep iterator with 0 elements."))
end

first_kwargs, _ = first_state
region_iter = RegionIterator(problem; sweep = 1, first_kwargs...)

return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1)
end
end

islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs))

region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter
problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter))

state(sweep_iter::SweepIterator) = sweep_iter.which_sweep
Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs)
function increment!(sweep_iter::SweepIterator)
sweep_iter.which_sweep += 1
sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs)
update_region_iterator!(sweep_iter; sweep_kwargs...)
return sweep_iter
end

function update_region_iterator!(iterator::SweepIterator; kwargs...)
sweep = state(iterator)
iterator.region_iter = new_region_iterator(iterator.region_iter; sweep, kwargs...)
return iterator
end

function compute!(sweep_iter::SweepIterator)
for _ in sweep_iter.region_iter
# TODO: Is it sensible to execute the default region callback function?
end
return
end

# More basic constructor where sweep_kwargs are constant throughout sweeps
function SweepIterator(problem, nsweeps::Int; sweep_kwargs...)
# Initialize this to an empty RegionIterator
sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps)
return SweepIterator(problem, sweep_kwargs_iter)
end
161 changes: 161 additions & 0 deletions test/test_iterators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
using Test: @test, @testset, @test_throws
import ITensorNetworksNext as ITensorNetworks
using .ITensorNetworks: SweepIterator, RegionIterator, islaststep, state, increment!, compute!, eachregion

module TestIteratorUtils

import ITensorNetworksNext as ITensorNetworks
using .ITensorNetworks

struct TestProblem <: ITensorNetworks.AbstractProblem
data::Vector{Int}
end
ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)]
function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem})
kwargs = ITensorNetworks.region_kwargs(iter)
push!(ITensorNetworks.problem(iter).data, kwargs.val)
return iter
end


mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator
state::Int
max::Int
output::Vector{Int}
end

ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1
Base.length(TI::TestIterator) = TI.max
ITensorNetworks.state(TI::TestIterator) = TI.state
function ITensorNetworks.compute!(TI::TestIterator)
push!(TI.output, ITensorNetworks.state(TI))
return TI
end

mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator
parent::TestIterator
end

Base.length(SA::SquareAdapter) = length(SA.parent)
ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent)
ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent)
function ITensorNetworks.compute!(SA::SquareAdapter)
ITensorNetworks.compute!(SA.parent)
return last(SA.parent.output)^2
end

end

@testset "Iterators" begin

import .TestIteratorUtils

@testset "`AbstractNetworkIterator` Interface" begin

@testset "Edge cases" begin
TI = TestIteratorUtils.TestIterator(1, 1, [])
cb = []
@test islaststep(TI)
for _ in TI
@test islaststep(TI)
push!(cb, state(TI))
end
@test length(cb) == 1
@test length(TI.output) == 1
@test only(cb) == 1

prob = TestIteratorUtils.TestProblem([])
@test_throws BoundsError SweepIterator(prob, 0)
@test_throws BoundsError RegionIterator(prob, [], 1)
end

TI = TestIteratorUtils.TestIterator(1, 4, [])

@test !islaststep((TI))

# First iterator should compute only
rv, st = iterate(TI)
@test !islaststep((TI))
@test !st
@test rv === TI
@test length(TI.output) == 1
@test only(TI.output) == 1
@test state(TI) == 1
@test !st

rv, st = iterate(TI, st)
@test !islaststep((TI))
@test !st
@test length(TI.output) == 2
@test state(TI) == 2
@test TI.output == [1, 2]

increment!(TI)
@test !islaststep((TI))
@test state(TI) == 3
@test length(TI.output) == 2
@test TI.output == [1, 2]

compute!(TI)
@test !islaststep((TI))
@test state(TI) == 3
@test length(TI.output) == 3
@test TI.output == [1, 2, 3]

# Final Step
iterate(TI, false)
@test islaststep((TI))
@test state(TI) == 4
@test length(TI.output) == 4
@test TI.output == [1, 2, 3, 4]

@test iterate(TI, false) === nothing

TI = TestIteratorUtils.TestIterator(1, 5, [])

cb = []

for _ in TI
@test length(cb) == length(TI.output) - 1
@test cb == (TI.output)[1:(end - 1)]
push!(cb, state(TI))
@test cb == TI.output
end

@test islaststep((TI))
@test length(TI.output) == 5
@test length(cb) == 5
@test cb == TI.output


TI = TestIteratorUtils.TestIterator(1, 5, [])
end

@testset "Adapters" begin
TI = TestIteratorUtils.TestIterator(1, 5, [])
SA = TestIteratorUtils.SquareAdapter(TI)

@testset "Generic" begin

i = 0
for rv in SA
i += 1
@test rv isa Int
@test rv == i^2
@test state(SA) == i
end

@test islaststep((SA))

TI = TestIteratorUtils.TestIterator(1, 5, [])
SA = TestIteratorUtils.SquareAdapter(TI)

SA_c = collect(SA)

@test SA_c isa Vector
@test length(SA_c) == 5
@test SA_c == [1, 4, 9, 16, 25]

end
end
end
Loading