Skip to content

Commit 4ce453e

Browse files
author
Jack Dunham
committed
Add tests for EachRegion and eachregion wrapper functions
1 parent fed9137 commit 4ce453e

File tree

1 file changed

+65
-15
lines changed

1 file changed

+65
-15
lines changed

test/solvers/test_iterators.jl

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
using Test: @test, @testset
2-
using ITensorNetworks: laststep, state, increment!, compute!
2+
using ITensorNetworks: SweepIterator, laststep, state, increment!, compute!, eachregion
33

44
module TestIteratorUtils
55

66
using ITensorNetworks
77

8+
struct TestProblem <: ITensorNetworks.AbstractProblem
9+
data::Vector{Int}
10+
end
11+
ITensorNetworks.region_plan(::TestProblem) = [:a => (; val=1), :b => (; val=2)]
12+
function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem})
13+
kwargs = ITensorNetworks.current_region_kwargs(iter)
14+
push!(ITensorNetworks.problem(iter).data, kwargs.val)
15+
return iter
16+
end
17+
18+
819
mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator
920
state::Int
1021
max::Int
@@ -35,7 +46,7 @@ end
3546

3647
@testset "Iterators" begin
3748

38-
using .TestIteratorUtils: TestIterator, SquareAdapter
49+
using .TestIteratorUtils: TestIterator, SquareAdapter, TestProblem
3950

4051
@testset "`AbstractNetworkIterator` Interface" begin
4152
TI = TestIterator(1, 4, [])
@@ -104,23 +115,62 @@ end
104115
TI = TestIterator(1, 5, [])
105116
SA = SquareAdapter(TI)
106117

107-
i = 0
108-
for rv in SA
109-
i += 1
110-
@test rv isa Int
111-
@test rv == i^2
112-
@test state(SA) == i
118+
@testset "Generic" begin
119+
120+
i = 0
121+
for rv in SA
122+
i += 1
123+
@test rv isa Int
124+
@test rv == i^2
125+
@test state(SA) == i
126+
end
127+
128+
@test laststep((SA))
129+
130+
TI = TestIterator(1, 5, [])
131+
SA = SquareAdapter(TI)
132+
133+
SA_c = collect(SA)
134+
135+
@test SA_c isa Vector
136+
@test length(SA_c) == 5
137+
@test SA_c == [1, 4, 9, 16, 25]
138+
113139
end
114140

115-
@test laststep((SA))
141+
@testset "EachRegion" begin
142+
prob = TestProblem([])
143+
prob_region = TestProblem([])
116144

117-
TI = TestIterator(1, 5, [])
118-
SA = SquareAdapter(TI)
145+
SI = SweepIterator(prob, 5)
146+
SI_region = SweepIterator(prob_region, 5)
147+
148+
callback = []
149+
callback_region = []
150+
151+
let i = 1
152+
for _ in SI
153+
push!(callback, i)
154+
i += 1
155+
end
156+
end
157+
158+
@test length(callback) == 5
119159

120-
SA_c = collect(SA)
160+
let i = 1
161+
for _ in eachregion(SI_region)
162+
push!(callback_region, i)
163+
i += 1
164+
end
165+
end
121166

122-
@test SA_c isa Vector
123-
@test length(SA_c) == 5
124-
@test SA_c == [1, 4, 9, 16, 25]
167+
@test length(callback_region) == 10
168+
169+
@test prob.data == prob_region.data
170+
171+
@test prob.data[1:2:end] == fill(1, 5)
172+
@test prob.data[2:2:end] == fill(2, 5)
173+
174+
end
125175
end
126176
end

0 commit comments

Comments
 (0)