diff --git a/Project.toml b/Project.toml index 42df730..4eb513e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.11" +version = "0.1.12" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index b59c3bd..6e2a466 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -4,5 +4,8 @@ include("lazynameddimsarrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") include("contract_network.jl") +include("abstract_problem.jl") +include("iterators.jl") +include("adapters.jl") end diff --git a/src/abstract_problem.jl b/src/abstract_problem.jl new file mode 100644 index 0000000..5a65e0a --- /dev/null +++ b/src/abstract_problem.jl @@ -0,0 +1 @@ +abstract type AbstractProblem end diff --git a/src/adapters.jl b/src/adapters.jl new file mode 100644 index 0000000..28318fb --- /dev/null +++ b/src/adapters.jl @@ -0,0 +1,45 @@ +""" + struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator + +Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the +process. This allows one to manually call a custom `compute!` or insert their own code it in +the loop body in place of `compute!`. +""" +struct IncrementOnly{S <: AbstractNetworkIterator} <: AbstractNetworkIterator + parent::S +end + +islaststep(adapter::IncrementOnly) = islaststep(adapter.parent) +increment!(adapter::IncrementOnly) = increment!(adapter.parent) +compute!(adapter::IncrementOnly) = adapter + +IncrementOnly(adapter::IncrementOnly) = adapter + +""" + struct EachRegion{SweepIterator} <: AbstractNetworkIterator + +Adapter that flattens each region iterator in the parent sweep iterator into a single +iterator. +""" +struct EachRegion{SI <: SweepIterator} <: AbstractNetworkIterator + parent::SI +end + +# In keeping with Julia convention. +eachregion(iter::SweepIterator) = EachRegion(iter) + +# Essential definitions +function islaststep(adapter::EachRegion) + region_iter = region_iterator(adapter.parent) + return islaststep(adapter.parent) && islaststep(region_iter) +end +function increment!(adapter::EachRegion) + region_iter = region_iterator(adapter.parent) + islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter) + return adapter +end +function compute!(adapter::EachRegion) + region_iter = region_iterator(adapter.parent) + compute!(region_iter) + return adapter +end diff --git a/src/iterators.jl b/src/iterators.jl new file mode 100644 index 0000000..62d5b21 --- /dev/null +++ b/src/iterators.jl @@ -0,0 +1,170 @@ +""" + abstract type AbstractNetworkIterator + +A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins +with a call to `increment!` before executing `compute!`, however the initial call to +`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that +this call is implict. Termination of the iterator is controlled by the function `done`. +""" +abstract type AbstractNetworkIterator end + +# We use greater than or equals here as we increment the state at the start of the iteration +islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) + +function Base.iterate(iterator::AbstractNetworkIterator, init = true) + # The assumption is that first "increment!" is implicit, therefore we must skip the + # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not + # defined when length < 1, + init || islaststep(iterator) && return nothing + # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* + # define a method for increment! This way we avoid cases where one may wish to nest + # calls to different step! methods accidentaly incrementing multiple times. + init || increment!(iterator) + rv = compute!(iterator) + return rv, false +end + +increment!(iterator::AbstractNetworkIterator) = throw(MethodError(increment!, Tuple{typeof(iterator)})) +compute!(iterator::AbstractNetworkIterator) = iterator + +step!(iterator::AbstractNetworkIterator) = step!(identity, iterator) +function step!(f, iterator::AbstractNetworkIterator) + compute!(iterator) + f(iterator) + increment!(iterator) + return iterator +end + +# +# RegionIterator +# +""" + struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator +""" +mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator + problem::Problem + region_plan::RegionPlan + which_region::Int + const which_sweep::Int + function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R} + if isempty(region_plan) + throw(ArgumentError("Cannot construct a region iterator with 0 elements.")) + end + return new{P, R}(problem, region_plan, 1, sweep) + end +end + +function RegionIterator(problem; sweep, sweep_kwargs...) + plan = region_plan(problem; sweep_kwargs...) + return RegionIterator(problem, plan, sweep) +end + +state(region_iter::RegionIterator) = region_iter.which_region +Base.length(region_iter::RegionIterator) = length(region_iter.region_plan) + +problem(region_iter::RegionIterator) = region_iter.problem + +function current_region_plan(region_iter::RegionIterator) + return region_iter.region_plan[region_iter.which_region] +end + +function current_region(region_iter::RegionIterator) + region, _ = current_region_plan(region_iter) + return region +end + +function region_kwargs(region_iter::RegionIterator) + _, kwargs = current_region_plan(region_iter) + return kwargs +end +function region_kwargs(f::Function, iter::RegionIterator) + return get(region_kwargs(iter), Symbol(f, :_kwargs), (;)) +end + +function prev_region(region_iter::RegionIterator) + state(region_iter) <= 1 && return nothing + prev, _ = region_iter.region_plan[region_iter.which_region - 1] + return prev +end + +function next_region(region_iter::RegionIterator) + islaststep(region_iter) && return nothing + next, _ = region_iter.region_plan[region_iter.which_region + 1] + return next +end + +# +# Functions associated with RegionIterator +# +function increment!(region_iter::RegionIterator) + region_iter.which_region += 1 + return region_iter +end + +function compute!(iter::RegionIterator) + extract!(iter; region_kwargs(extract!, iter)...) + update!(iter; region_kwargs(update!, iter)...) + insert!(iter; region_kwargs(insert!, iter)...) + + return iter +end + +region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) + +# +# SweepIterator +# + +mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator + region_iter::RegionIterator{Problem} + sweep_kwargs::Iterators.Stateful{Iter} + which_sweep::Int + function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} + stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) + + first_state = Iterators.peel(stateful_sweep_kwargs) + + if isnothing(first_state) + throw(ArgumentError("Cannot construct a sweep iterator with 0 elements.")) + end + + first_kwargs, _ = first_state + region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) + + return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) + end +end + +islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) + +region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter +problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) + +state(sweep_iter::SweepIterator) = sweep_iter.which_sweep +Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) +function increment!(sweep_iter::SweepIterator) + sweep_iter.which_sweep += 1 + sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) + update_region_iterator!(sweep_iter; sweep_kwargs...) + return sweep_iter +end + +function update_region_iterator!(iterator::SweepIterator; kwargs...) + sweep = state(iterator) + iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...) + return iterator +end + +function compute!(sweep_iter::SweepIterator) + for _ in sweep_iter.region_iter + # TODO: Is it sensible to execute the default region callback function? + end + return +end + +# More basic constructor where sweep_kwargs are constant throughout sweeps +function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) + # Initialize this to an empty RegionIterator + sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps) + return SweepIterator(problem, sweep_kwargs_iter) +end diff --git a/test/test_iterators.jl b/test/test_iterators.jl new file mode 100644 index 0000000..a17c7be --- /dev/null +++ b/test/test_iterators.jl @@ -0,0 +1,221 @@ +using Test: @test, @testset, @test_throws +import ITensorNetworksNext as ITensorNetworks +using .ITensorNetworks: RegionIterator, SweepIterator, IncrementOnly, compute!, increment!, islaststep, state, eachregion + +module TestIteratorUtils + + import ITensorNetworksNext as ITensorNetworks + using .ITensorNetworks + + struct TestProblem <: ITensorNetworks.AbstractProblem + data::Vector{Int} + end + ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)] + function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem}) + kwargs = ITensorNetworks.region_kwargs(iter) + push!(ITensorNetworks.problem(iter).data, kwargs.val) + return iter + end + + + mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator + state::Int + max::Int + output::Vector{Int} + end + + ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1 + Base.length(TI::TestIterator) = TI.max + ITensorNetworks.state(TI::TestIterator) = TI.state + function ITensorNetworks.compute!(TI::TestIterator) + push!(TI.output, ITensorNetworks.state(TI)) + return TI + end + + mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator + parent::TestIterator + end + + Base.length(SA::SquareAdapter) = length(SA.parent) + ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent) + ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent) + function ITensorNetworks.compute!(SA::SquareAdapter) + ITensorNetworks.compute!(SA.parent) + return last(SA.parent.output)^2 + end + +end + +@testset "Iterators" begin + + import .TestIteratorUtils + + @testset "`AbstractNetworkIterator` Interface" begin + + @testset "Edge cases" begin + TI = TestIteratorUtils.TestIterator(1, 1, []) + cb = [] + @test islaststep(TI) + for _ in TI + @test islaststep(TI) + push!(cb, state(TI)) + end + @test length(cb) == 1 + @test length(TI.output) == 1 + @test only(cb) == 1 + + prob = TestIteratorUtils.TestProblem([]) + @test_throws ArgumentError SweepIterator(prob, 0) + @test_throws ArgumentError RegionIterator(prob, [], 1) + end + + TI = TestIteratorUtils.TestIterator(1, 4, []) + + @test !islaststep((TI)) + + # First iterator should compute only + rv, st = iterate(TI) + @test !islaststep((TI)) + @test !st + @test rv === TI + @test length(TI.output) == 1 + @test only(TI.output) == 1 + @test state(TI) == 1 + @test !st + + rv, st = iterate(TI, st) + @test !islaststep((TI)) + @test !st + @test length(TI.output) == 2 + @test state(TI) == 2 + @test TI.output == [1, 2] + + increment!(TI) + @test !islaststep((TI)) + @test state(TI) == 3 + @test length(TI.output) == 2 + @test TI.output == [1, 2] + + compute!(TI) + @test !islaststep((TI)) + @test state(TI) == 3 + @test length(TI.output) == 3 + @test TI.output == [1, 2, 3] + + # Final Step + iterate(TI, false) + @test islaststep((TI)) + @test state(TI) == 4 + @test length(TI.output) == 4 + @test TI.output == [1, 2, 3, 4] + + @test iterate(TI, false) === nothing + + TI = TestIteratorUtils.TestIterator(1, 5, []) + + cb = [] + + for _ in TI + @test length(cb) == length(TI.output) - 1 + @test cb == (TI.output)[1:(end - 1)] + push!(cb, state(TI)) + @test cb == TI.output + end + + @test islaststep((TI)) + @test length(TI.output) == 5 + @test length(cb) == 5 + @test cb == TI.output + + + TI = TestIteratorUtils.TestIterator(1, 5, []) + end + + @testset "Adapters" begin + TI = TestIteratorUtils.TestIterator(1, 5, []) + SA = TestIteratorUtils.SquareAdapter(TI) + + @testset "Generic" begin + + i = 0 + for rv in SA + i += 1 + @test rv isa Int + @test rv == i^2 + @test state(SA) == i + end + + @test islaststep((SA)) + + TI = TestIteratorUtils.TestIterator(1, 5, []) + SA = TestIteratorUtils.SquareAdapter(TI) + + SA_c = collect(SA) + + @test SA_c isa Vector + @test length(SA_c) == 5 + @test SA_c == [1, 4, 9, 16, 25] + + end + + @testset "IncrementOnly" begin + TI = TestIteratorUtils.TestIterator(1, 5, []) + NI = IncrementOnly(TI) + + NI_c = [] + + for _ in IncrementOnly(TI) + push!(NI_c, state(TI)) + end + + @test length(NI_c) == 5 + @test isempty(TI.output) + end + + @testset "EachRegion" begin + prob = TestIteratorUtils.TestProblem([]) + prob_region = TestIteratorUtils.TestProblem([]) + + SI = SweepIterator(prob, 5) + SI_region = SweepIterator(prob_region, 5) + + callback = [] + callback_region = [] + + let i = 1 + for _ in SI + push!(callback, i) + i += 1 + end + end + + @test length(callback) == 5 + + let i = 1 + for _ in eachregion(SI_region) + push!(callback_region, i) + i += 1 + end + end + + @test length(callback_region) == 10 + + @test prob.data == prob_region.data + + @test prob.data[1:2:end] == fill(1, 5) + @test prob.data[2:2:end] == fill(2, 5) + + + let i = 1, prob = TestIteratorUtils.TestProblem([]) + SI = SweepIterator(prob, 1) + cb = [] + for _ in eachregion(SI) + push!(cb, i) + i += 1 + end + @test length(cb) == 2 + end + + end + end +end