|
1 | 1 | using Test: @test, @testset |
2 | | -using ITensorNetworks: laststep, state, increment!, compute! |
| 2 | +using ITensorNetworks: SweepIterator, laststep, state, increment!, compute!, eachregion |
3 | 3 |
|
4 | 4 | module TestIteratorUtils |
5 | 5 |
|
6 | 6 | using ITensorNetworks |
7 | 7 |
|
| 8 | +struct TestProblem <: ITensorNetworks.AbstractProblem |
| 9 | + data::Vector{Int} |
| 10 | +end |
| 11 | +ITensorNetworks.region_plan(::TestProblem) = [:a => (; val=1), :b => (; val=2)] |
| 12 | +function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem}) |
| 13 | + kwargs = ITensorNetworks.current_region_kwargs(iter) |
| 14 | + push!(ITensorNetworks.problem(iter).data, kwargs.val) |
| 15 | + return iter |
| 16 | +end |
| 17 | + |
| 18 | + |
8 | 19 | mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator |
9 | 20 | state::Int |
10 | 21 | max::Int |
|
35 | 46 |
|
36 | 47 | @testset "Iterators" begin |
37 | 48 |
|
38 | | - using .TestIteratorUtils: TestIterator, SquareAdapter |
| 49 | + using .TestIteratorUtils: TestIterator, SquareAdapter, TestProblem |
39 | 50 |
|
40 | 51 | @testset "`AbstractNetworkIterator` Interface" begin |
41 | 52 | TI = TestIterator(1, 4, []) |
@@ -104,23 +115,62 @@ end |
104 | 115 | TI = TestIterator(1, 5, []) |
105 | 116 | SA = SquareAdapter(TI) |
106 | 117 |
|
107 | | - i = 0 |
108 | | - for rv in SA |
109 | | - i += 1 |
110 | | - @test rv isa Int |
111 | | - @test rv == i^2 |
112 | | - @test state(SA) == i |
| 118 | + @testset "Generic" begin |
| 119 | + |
| 120 | + i = 0 |
| 121 | + for rv in SA |
| 122 | + i += 1 |
| 123 | + @test rv isa Int |
| 124 | + @test rv == i^2 |
| 125 | + @test state(SA) == i |
| 126 | + end |
| 127 | + |
| 128 | + @test laststep((SA)) |
| 129 | + |
| 130 | + TI = TestIterator(1, 5, []) |
| 131 | + SA = SquareAdapter(TI) |
| 132 | + |
| 133 | + SA_c = collect(SA) |
| 134 | + |
| 135 | + @test SA_c isa Vector |
| 136 | + @test length(SA_c) == 5 |
| 137 | + @test SA_c == [1, 4, 9, 16, 25] |
| 138 | + |
113 | 139 | end |
114 | 140 |
|
115 | | - @test laststep((SA)) |
| 141 | + @testset "EachRegion" begin |
| 142 | + prob = TestProblem([]) |
| 143 | + prob_region = TestProblem([]) |
116 | 144 |
|
117 | | - TI = TestIterator(1, 5, []) |
118 | | - SA = SquareAdapter(TI) |
| 145 | + SI = SweepIterator(prob, 5) |
| 146 | + SI_region = SweepIterator(prob_region, 5) |
| 147 | + |
| 148 | + callback = [] |
| 149 | + callback_region = [] |
| 150 | + |
| 151 | + let i = 1 |
| 152 | + for _ in SI |
| 153 | + push!(callback, i) |
| 154 | + i += 1 |
| 155 | + end |
| 156 | + end |
| 157 | + |
| 158 | + @test length(callback) == 5 |
119 | 159 |
|
120 | | - SA_c = collect(SA) |
| 160 | + let i = 1 |
| 161 | + for _ in eachregion(SI_region) |
| 162 | + push!(callback_region, i) |
| 163 | + i += 1 |
| 164 | + end |
| 165 | + end |
121 | 166 |
|
122 | | - @test SA_c isa Vector |
123 | | - @test length(SA_c) == 5 |
124 | | - @test SA_c == [1, 4, 9, 16, 25] |
| 167 | + @test length(callback_region) == 10 |
| 168 | + |
| 169 | + @test prob.data == prob_region.data |
| 170 | + |
| 171 | + @test prob.data[1:2:end] == fill(1, 5) |
| 172 | + @test prob.data[2:2:end] == fill(2, 5) |
| 173 | + |
| 174 | + end |
125 | 175 | end |
126 | 176 | end |
0 commit comments