-
Notifications
You must be signed in to change notification settings - Fork 2
Reintroduce AbstractNetworkIterator and AbstractProblem interface
#17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
dc41a58
Reintroduce `AbstractNetworkIterator` and `AbstractProblem` interface
d739a6e
Alphabetise `using` statements.
jack-dunham 481089a
Include type signature in `increment!` definition.
jack-dunham b9386f2
Improve error handling of empty iterators
cf47282
Remove try catch statement in `RegionIterator` construction for empty…
5207a3a
Interface now assumes `local_state` is stored in the `AbstractProblem…
790fd68
Remove `new_region_iterator` function.
7e9ba0f
Add `adapters.jl` code
715946c
Bump version from 0.1.11 to 0.1.12
mtfishman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| abstract type AbstractProblem end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| """ | ||
| abstract type AbstractNetworkIterator | ||
|
|
||
| A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins | ||
| with a call to `increment!` before executing `compute!`, however the initial call to | ||
| `iterate` skips the `increment!` call as it is assumed the iterator is initalized such that | ||
| this call is implict. Termination of the iterator is controlled by the function `done`. | ||
| """ | ||
| abstract type AbstractNetworkIterator end | ||
|
|
||
| # We use greater than or equals here as we increment the state at the start of the iteration | ||
| islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) | ||
|
|
||
| function Base.iterate(iterator::AbstractNetworkIterator, init = true) | ||
| # The assumption is that first "increment!" is implicit, therefore we must skip the | ||
| # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not | ||
| # defined when length < 1, | ||
| init || islaststep(iterator) && return nothing | ||
| # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* | ||
| # define a method for increment! This way we avoid cases where one may wish to nest | ||
| # calls to different step! methods accidentaly incrementing multiple times. | ||
| init || increment!(iterator) | ||
| rv = compute!(iterator) | ||
| return rv, false | ||
| end | ||
|
|
||
| function increment! end | ||
| compute!(iterator::AbstractNetworkIterator) = iterator | ||
|
|
||
| step!(iterator::AbstractNetworkIterator) = step!(identity, iterator) | ||
| function step!(f, iterator::AbstractNetworkIterator) | ||
| compute!(iterator) | ||
| f(iterator) | ||
| increment!(iterator) | ||
| return iterator | ||
| end | ||
|
|
||
| # | ||
| # RegionIterator | ||
| # | ||
| """ | ||
| struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator | ||
| """ | ||
| mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator | ||
| problem::Problem | ||
| region_plan::RegionPlan | ||
| which_region::Int | ||
| const which_sweep::Int | ||
| function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R} | ||
| if length(region_plan) == 0 | ||
| throw(BoundsError("Cannot construct a region iterator with 0 elements.")) | ||
| end | ||
| return new{P, R}(problem, region_plan, 1, sweep) | ||
| end | ||
| end | ||
|
|
||
| function RegionIterator(problem; sweep, sweep_kwargs...) | ||
| plan = region_plan(problem; sweep_kwargs...) | ||
| return RegionIterator(problem, plan, sweep) | ||
| end | ||
|
|
||
| function new_region_iterator(iterator::RegionIterator; sweep_kwargs...) | ||
| return RegionIterator(iterator.problem; sweep_kwargs...) | ||
| end | ||
jack-dunham marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| state(region_iter::RegionIterator) = region_iter.which_region | ||
| Base.length(region_iter::RegionIterator) = length(region_iter.region_plan) | ||
|
|
||
| problem(region_iter::RegionIterator) = region_iter.problem | ||
|
|
||
| function current_region_plan(region_iter::RegionIterator) | ||
| return region_iter.region_plan[region_iter.which_region] | ||
| end | ||
|
|
||
| function current_region(region_iter::RegionIterator) | ||
| region, _ = current_region_plan(region_iter) | ||
| return region | ||
| end | ||
|
|
||
| function region_kwargs(region_iter::RegionIterator) | ||
| _, kwargs = current_region_plan(region_iter) | ||
| return kwargs | ||
| end | ||
| function region_kwargs(f::Function, iter::RegionIterator) | ||
| return get(region_kwargs(iter), Symbol(f, :_kwargs), (;)) | ||
| end | ||
|
|
||
| function prev_region(region_iter::RegionIterator) | ||
| state(region_iter) <= 1 && return nothing | ||
| prev, _ = region_iter.region_plan[region_iter.which_region - 1] | ||
| return prev | ||
| end | ||
|
|
||
| function next_region(region_iter::RegionIterator) | ||
| islaststep(region_iter) && return nothing | ||
| next, _ = region_iter.region_plan[region_iter.which_region + 1] | ||
| return next | ||
| end | ||
|
|
||
| # | ||
| # Functions associated with RegionIterator | ||
| # | ||
| function increment!(region_iter::RegionIterator) | ||
| region_iter.which_region += 1 | ||
| return region_iter | ||
| end | ||
|
|
||
| function compute!(iter::RegionIterator) | ||
| _, local_state = extract!(iter; region_kwargs(extract!, iter)...) | ||
| _, local_state = update!(iter, local_state; region_kwargs(update!, iter)...) | ||
| insert!(iter, local_state; region_kwargs(insert!, iter)...) | ||
jack-dunham marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| return iter | ||
| end | ||
|
|
||
| region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) | ||
|
|
||
| # | ||
| # SweepIterator | ||
| # | ||
|
|
||
| mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator | ||
| region_iter::RegionIterator{Problem} | ||
| sweep_kwargs::Iterators.Stateful{Iter} | ||
| which_sweep::Int | ||
| function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} | ||
| stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) | ||
| first_state = Iterators.peel(stateful_sweep_kwargs) | ||
|
|
||
| if isnothing(first_state) | ||
| throw(BoundsError("Cannot construct a sweep iterator with 0 elements.")) | ||
| end | ||
|
|
||
| first_kwargs, _ = first_state | ||
| region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) | ||
|
|
||
| return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) | ||
| end | ||
| end | ||
|
|
||
| islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) | ||
|
|
||
| region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter | ||
| problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) | ||
|
|
||
| state(sweep_iter::SweepIterator) = sweep_iter.which_sweep | ||
| Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) | ||
| function increment!(sweep_iter::SweepIterator) | ||
| sweep_iter.which_sweep += 1 | ||
| sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) | ||
| update_region_iterator!(sweep_iter; sweep_kwargs...) | ||
| return sweep_iter | ||
| end | ||
|
|
||
| function update_region_iterator!(iterator::SweepIterator; kwargs...) | ||
| sweep = state(iterator) | ||
| iterator.region_iter = new_region_iterator(iterator.region_iter; sweep, kwargs...) | ||
| return iterator | ||
| end | ||
|
|
||
| function compute!(sweep_iter::SweepIterator) | ||
| for _ in sweep_iter.region_iter | ||
| # TODO: Is it sensible to execute the default region callback function? | ||
| end | ||
| return | ||
| end | ||
|
|
||
| # More basic constructor where sweep_kwargs are constant throughout sweeps | ||
| function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) | ||
| # Initialize this to an empty RegionIterator | ||
| sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps) | ||
| return SweepIterator(problem, sweep_kwargs_iter) | ||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| using Test: @test, @testset, @test_throws | ||
| import ITensorNetworksNext as ITensorNetworks | ||
| using .ITensorNetworks: SweepIterator, RegionIterator, islaststep, state, increment!, compute!, eachregion | ||
jack-dunham marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| module TestIteratorUtils | ||
|
|
||
| import ITensorNetworksNext as ITensorNetworks | ||
| using .ITensorNetworks | ||
|
|
||
| struct TestProblem <: ITensorNetworks.AbstractProblem | ||
| data::Vector{Int} | ||
| end | ||
| ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)] | ||
| function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem}) | ||
| kwargs = ITensorNetworks.region_kwargs(iter) | ||
| push!(ITensorNetworks.problem(iter).data, kwargs.val) | ||
| return iter | ||
| end | ||
|
|
||
|
|
||
| mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator | ||
| state::Int | ||
| max::Int | ||
| output::Vector{Int} | ||
| end | ||
|
|
||
| ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1 | ||
| Base.length(TI::TestIterator) = TI.max | ||
| ITensorNetworks.state(TI::TestIterator) = TI.state | ||
| function ITensorNetworks.compute!(TI::TestIterator) | ||
| push!(TI.output, ITensorNetworks.state(TI)) | ||
| return TI | ||
| end | ||
|
|
||
| mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator | ||
| parent::TestIterator | ||
| end | ||
|
|
||
| Base.length(SA::SquareAdapter) = length(SA.parent) | ||
| ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent) | ||
| ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent) | ||
| function ITensorNetworks.compute!(SA::SquareAdapter) | ||
| ITensorNetworks.compute!(SA.parent) | ||
| return last(SA.parent.output)^2 | ||
| end | ||
|
|
||
| end | ||
|
|
||
| @testset "Iterators" begin | ||
|
|
||
| import .TestIteratorUtils | ||
|
|
||
| @testset "`AbstractNetworkIterator` Interface" begin | ||
|
|
||
| @testset "Edge cases" begin | ||
| TI = TestIteratorUtils.TestIterator(1, 1, []) | ||
| cb = [] | ||
| @test islaststep(TI) | ||
| for _ in TI | ||
| @test islaststep(TI) | ||
| push!(cb, state(TI)) | ||
| end | ||
| @test length(cb) == 1 | ||
| @test length(TI.output) == 1 | ||
| @test only(cb) == 1 | ||
|
|
||
| prob = TestIteratorUtils.TestProblem([]) | ||
| @test_throws BoundsError SweepIterator(prob, 0) | ||
| @test_throws BoundsError RegionIterator(prob, [], 1) | ||
| end | ||
|
|
||
| TI = TestIteratorUtils.TestIterator(1, 4, []) | ||
|
|
||
| @test !islaststep((TI)) | ||
|
|
||
| # First iterator should compute only | ||
| rv, st = iterate(TI) | ||
| @test !islaststep((TI)) | ||
| @test !st | ||
| @test rv === TI | ||
| @test length(TI.output) == 1 | ||
| @test only(TI.output) == 1 | ||
| @test state(TI) == 1 | ||
| @test !st | ||
|
|
||
| rv, st = iterate(TI, st) | ||
| @test !islaststep((TI)) | ||
| @test !st | ||
| @test length(TI.output) == 2 | ||
| @test state(TI) == 2 | ||
| @test TI.output == [1, 2] | ||
|
|
||
| increment!(TI) | ||
| @test !islaststep((TI)) | ||
| @test state(TI) == 3 | ||
| @test length(TI.output) == 2 | ||
| @test TI.output == [1, 2] | ||
|
|
||
| compute!(TI) | ||
| @test !islaststep((TI)) | ||
| @test state(TI) == 3 | ||
| @test length(TI.output) == 3 | ||
| @test TI.output == [1, 2, 3] | ||
|
|
||
| # Final Step | ||
| iterate(TI, false) | ||
| @test islaststep((TI)) | ||
| @test state(TI) == 4 | ||
| @test length(TI.output) == 4 | ||
| @test TI.output == [1, 2, 3, 4] | ||
|
|
||
| @test iterate(TI, false) === nothing | ||
|
|
||
| TI = TestIteratorUtils.TestIterator(1, 5, []) | ||
|
|
||
| cb = [] | ||
|
|
||
| for _ in TI | ||
| @test length(cb) == length(TI.output) - 1 | ||
| @test cb == (TI.output)[1:(end - 1)] | ||
| push!(cb, state(TI)) | ||
| @test cb == TI.output | ||
| end | ||
|
|
||
| @test islaststep((TI)) | ||
| @test length(TI.output) == 5 | ||
| @test length(cb) == 5 | ||
| @test cb == TI.output | ||
|
|
||
|
|
||
| TI = TestIteratorUtils.TestIterator(1, 5, []) | ||
| end | ||
|
|
||
| @testset "Adapters" begin | ||
| TI = TestIteratorUtils.TestIterator(1, 5, []) | ||
| SA = TestIteratorUtils.SquareAdapter(TI) | ||
|
|
||
| @testset "Generic" begin | ||
|
|
||
| i = 0 | ||
| for rv in SA | ||
| i += 1 | ||
| @test rv isa Int | ||
| @test rv == i^2 | ||
| @test state(SA) == i | ||
| end | ||
|
|
||
| @test islaststep((SA)) | ||
|
|
||
| TI = TestIteratorUtils.TestIterator(1, 5, []) | ||
| SA = TestIteratorUtils.SquareAdapter(TI) | ||
|
|
||
| SA_c = collect(SA) | ||
|
|
||
| @test SA_c isa Vector | ||
| @test length(SA_c) == 5 | ||
| @test SA_c == [1, 4, 9, 16, 25] | ||
|
|
||
| end | ||
| end | ||
| end | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.