Skip to content
2 changes: 2 additions & 0 deletions src/ITensorNetworksNext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ include("lazynameddimsarrays.jl")
include("abstracttensornetwork.jl")
include("tensornetwork.jl")
include("contract_network.jl")
include("abstract_problem.jl")
include("iterators.jl")

end
1 change: 1 addition & 0 deletions src/abstract_problem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
abstract type AbstractProblem end
179 changes: 179 additions & 0 deletions src/iterators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""
abstract type AbstractNetworkIterator

A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins
with a call to `increment!` before executing `compute!`, however the initial call to
`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that
this call is implict. Termination of the iterator is controlled by the function `done`.
"""
abstract type AbstractNetworkIterator end

# We use greater than or equals here as we increment the state at the start of the iteration
islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator)

function Base.iterate(iterator::AbstractNetworkIterator, init = true)
# The assumption is that first "increment!" is implicit, therefore we must skip the
# the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not
# defined when length < 1,
init || islaststep(iterator) && return nothing
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
# define a method for increment! This way we avoid cases where one may wish to nest
# calls to different step! methods accidentaly incrementing multiple times.
init || increment!(iterator)
rv = compute!(iterator)
return rv, false
end

increment!(iterator::AbstractNetworkIterator) = throw(MethodError(increment!, Tuple{typeof(iterator)}))
compute!(iterator::AbstractNetworkIterator) = iterator

step!(iterator::AbstractNetworkIterator) = step!(identity, iterator)
function step!(f, iterator::AbstractNetworkIterator)
compute!(iterator)
f(iterator)
increment!(iterator)
return iterator
end

#
# RegionIterator
#
"""
struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
"""
mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
problem::Problem
region_plan::RegionPlan
which_region::Int
const which_sweep::Int
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R}
try
first(region_plan)
catch e
if e isa BoundsError
throw(ArgumentError("Cannot construct a region iterator with 0 elements."))
end
rethrow()
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm always a bit weary to use try except in extenuating circumstances, when is it not enough to use something like isempty(region_plan)? It seems reasonable to expect whatever is input as region_plan to define some functions like isempty, length, etc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, is it really a problem to have an empty RegionIterator? Can't that just mean that that if you iterate it it doesn't do anything?

Copy link
Contributor Author

@jack-dunham jack-dunham Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like try as it allows one to both print a more descriptive error higher in the callstack as well as the error that would be inevitably be thrown later, but Ill be honest I don't have that much formal programming knowledge of how the statements should be used, so happy to use some other pattern if you think it would be more suitable.

Constructing empty AbstractNetworkIterators is undefined as the first iteration is implicit. This is a consequence of having both of the following behaviors:

  1. Work done during iteration. i.e. allowing for _ in iter end to do the full computation
  2. An AbstractNetworkIterator having a well defined state during the callback.

Might be easier to discuss this in person if you are uncomfortable with disallowing empty AbstractNetworkIterator, but I believe it to be fundamental to allow for the above two behaviors (that we already agreed were good behaviors).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think this would be easier to discuss in person, I've gotten a little bit lost around those edge cases in the design. If you want to write a custom error message I don't see why something closer to the previous design doesn't suffice, i.e.:

        if isempty(region_plan)
            throw(ArgumentError("Cannot construct a region iterator with 0 elements."))
        end

but maybe I'm missing something.

Copy link
Member

@mtfishman mtfishman Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the issue that isempty(region_plan) might sometimes throw an error? If so, in which cases would it throw an error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are not missing anything, I am overcomplicating things!

Copy link
Member

@mtfishman mtfishman Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the concern is that region_plan might be an iterator that doesn't have a length/size defined, maybe we can catch that more explicitly by checking Base.IteratorSize(region_plan) == Base.SizeUnknown() (https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-iteration). (But also that seems like a strange corner case that we can deal with if it comes up.)

return new{P, R}(problem, region_plan, 1, sweep)
end
end

function RegionIterator(problem; sweep, sweep_kwargs...)
plan = region_plan(problem; sweep_kwargs...)
return RegionIterator(problem, plan, sweep)
end

function new_region_iterator(iterator::RegionIterator; sweep_kwargs...)
return RegionIterator(iterator.problem; sweep_kwargs...)
end

state(region_iter::RegionIterator) = region_iter.which_region
Base.length(region_iter::RegionIterator) = length(region_iter.region_plan)

problem(region_iter::RegionIterator) = region_iter.problem

function current_region_plan(region_iter::RegionIterator)
return region_iter.region_plan[region_iter.which_region]
end

function current_region(region_iter::RegionIterator)
region, _ = current_region_plan(region_iter)
return region
end

function region_kwargs(region_iter::RegionIterator)
_, kwargs = current_region_plan(region_iter)
return kwargs
end
function region_kwargs(f::Function, iter::RegionIterator)
return get(region_kwargs(iter), Symbol(f, :_kwargs), (;))
end

function prev_region(region_iter::RegionIterator)
state(region_iter) <= 1 && return nothing
prev, _ = region_iter.region_plan[region_iter.which_region - 1]
return prev
end

function next_region(region_iter::RegionIterator)
islaststep(region_iter) && return nothing
next, _ = region_iter.region_plan[region_iter.which_region + 1]
return next
end

#
# Functions associated with RegionIterator
#
function increment!(region_iter::RegionIterator)
region_iter.which_region += 1
return region_iter
end

function compute!(iter::RegionIterator)
_, local_state = extract!(iter; region_kwargs(extract!, iter)...)
_, local_state = update!(iter, local_state; region_kwargs(update!, iter)...)
insert!(iter, local_state; region_kwargs(insert!, iter)...)

return iter
end

region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...)

#
# SweepIterator
#

mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator
region_iter::RegionIterator{Problem}
sweep_kwargs::Iterators.Stateful{Iter}
which_sweep::Int
function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter}
stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs)

first_state = Iterators.peel(stateful_sweep_kwargs)

if isnothing(first_state)
throw(ArgumentError("Cannot construct a sweep iterator with 0 elements."))
end

first_kwargs, _ = first_state
region_iter = RegionIterator(problem; sweep = 1, first_kwargs...)

return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1)
end
end

islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs))

region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter
problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter))

state(sweep_iter::SweepIterator) = sweep_iter.which_sweep
Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs)
function increment!(sweep_iter::SweepIterator)
sweep_iter.which_sweep += 1
sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs)
update_region_iterator!(sweep_iter; sweep_kwargs...)
return sweep_iter
end

function update_region_iterator!(iterator::SweepIterator; kwargs...)
sweep = state(iterator)
iterator.region_iter = new_region_iterator(iterator.region_iter; sweep, kwargs...)
return iterator
end

function compute!(sweep_iter::SweepIterator)
for _ in sweep_iter.region_iter
# TODO: Is it sensible to execute the default region callback function?
end
return
end

# More basic constructor where sweep_kwargs are constant throughout sweeps
function SweepIterator(problem, nsweeps::Int; sweep_kwargs...)
# Initialize this to an empty RegionIterator
sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps)
return SweepIterator(problem, sweep_kwargs_iter)
end
161 changes: 161 additions & 0 deletions test/test_iterators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
using Test: @test, @testset, @test_throws
import ITensorNetworksNext as ITensorNetworks
using .ITensorNetworks: RegionIterator, SweepIterator, compute!, increment!, islaststep, state

module TestIteratorUtils

import ITensorNetworksNext as ITensorNetworks
using .ITensorNetworks

struct TestProblem <: ITensorNetworks.AbstractProblem
data::Vector{Int}
end
ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)]
function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem})
kwargs = ITensorNetworks.region_kwargs(iter)
push!(ITensorNetworks.problem(iter).data, kwargs.val)
return iter
end


mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator
state::Int
max::Int
output::Vector{Int}
end

ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1
Base.length(TI::TestIterator) = TI.max
ITensorNetworks.state(TI::TestIterator) = TI.state
function ITensorNetworks.compute!(TI::TestIterator)
push!(TI.output, ITensorNetworks.state(TI))
return TI
end

mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator
parent::TestIterator
end

Base.length(SA::SquareAdapter) = length(SA.parent)
ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent)
ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent)
function ITensorNetworks.compute!(SA::SquareAdapter)
ITensorNetworks.compute!(SA.parent)
return last(SA.parent.output)^2
end

end

@testset "Iterators" begin

import .TestIteratorUtils

@testset "`AbstractNetworkIterator` Interface" begin

@testset "Edge cases" begin
TI = TestIteratorUtils.TestIterator(1, 1, [])
cb = []
@test islaststep(TI)
for _ in TI
@test islaststep(TI)
push!(cb, state(TI))
end
@test length(cb) == 1
@test length(TI.output) == 1
@test only(cb) == 1

prob = TestIteratorUtils.TestProblem([])
@test_throws ArgumentError SweepIterator(prob, 0)
@test_throws ArgumentError RegionIterator(prob, [], 1)
end

TI = TestIteratorUtils.TestIterator(1, 4, [])

@test !islaststep((TI))

# First iterator should compute only
rv, st = iterate(TI)
@test !islaststep((TI))
@test !st
@test rv === TI
@test length(TI.output) == 1
@test only(TI.output) == 1
@test state(TI) == 1
@test !st

rv, st = iterate(TI, st)
@test !islaststep((TI))
@test !st
@test length(TI.output) == 2
@test state(TI) == 2
@test TI.output == [1, 2]

increment!(TI)
@test !islaststep((TI))
@test state(TI) == 3
@test length(TI.output) == 2
@test TI.output == [1, 2]

compute!(TI)
@test !islaststep((TI))
@test state(TI) == 3
@test length(TI.output) == 3
@test TI.output == [1, 2, 3]

# Final Step
iterate(TI, false)
@test islaststep((TI))
@test state(TI) == 4
@test length(TI.output) == 4
@test TI.output == [1, 2, 3, 4]

@test iterate(TI, false) === nothing

TI = TestIteratorUtils.TestIterator(1, 5, [])

cb = []

for _ in TI
@test length(cb) == length(TI.output) - 1
@test cb == (TI.output)[1:(end - 1)]
push!(cb, state(TI))
@test cb == TI.output
end

@test islaststep((TI))
@test length(TI.output) == 5
@test length(cb) == 5
@test cb == TI.output


TI = TestIteratorUtils.TestIterator(1, 5, [])
end

@testset "Adapters" begin
TI = TestIteratorUtils.TestIterator(1, 5, [])
SA = TestIteratorUtils.SquareAdapter(TI)

@testset "Generic" begin

i = 0
for rv in SA
i += 1
@test rv isa Int
@test rv == i^2
@test state(SA) == i
end

@test islaststep((SA))

TI = TestIteratorUtils.TestIterator(1, 5, [])
SA = TestIteratorUtils.SquareAdapter(TI)

SA_c = collect(SA)

@test SA_c isa Vector
@test length(SA_c) == 5
@test SA_c == [1, 4, 9, 16, 25]

end
end
end