Skip to content

Commit 7e9ba0f

Browse files
author
Jack Dunham
committed
Add adapters.jl code
1 parent 790fd68 commit 7e9ba0f

File tree

3 files changed

+107
-1
lines changed

3 files changed

+107
-1
lines changed

src/ITensorNetworksNext.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ include("tensornetwork.jl")
66
include("contract_network.jl")
77
include("abstract_problem.jl")
88
include("iterators.jl")
9+
include("adapters.jl")
910

1011
end

src/adapters.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator
3+
4+
Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
5+
process. This allows one to manually call a custom `compute!` or insert their own code it in
6+
the loop body in place of `compute!`.
7+
"""
8+
struct IncrementOnly{S <: AbstractNetworkIterator} <: AbstractNetworkIterator
9+
parent::S
10+
end
11+
12+
islaststep(adapter::IncrementOnly) = islaststep(adapter.parent)
13+
increment!(adapter::IncrementOnly) = increment!(adapter.parent)
14+
compute!(adapter::IncrementOnly) = adapter
15+
16+
IncrementOnly(adapter::IncrementOnly) = adapter
17+
18+
"""
19+
struct EachRegion{SweepIterator} <: AbstractNetworkIterator
20+
21+
Adapter that flattens each region iterator in the parent sweep iterator into a single
22+
iterator.
23+
"""
24+
struct EachRegion{SI <: SweepIterator} <: AbstractNetworkIterator
25+
parent::SI
26+
end
27+
28+
# In keeping with Julia convention.
29+
eachregion(iter::SweepIterator) = EachRegion(iter)
30+
31+
# Essential definitions
32+
function islaststep(adapter::EachRegion)
33+
region_iter = region_iterator(adapter.parent)
34+
return islaststep(adapter.parent) && islaststep(region_iter)
35+
end
36+
function increment!(adapter::EachRegion)
37+
region_iter = region_iterator(adapter.parent)
38+
islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter)
39+
return adapter
40+
end
41+
function compute!(adapter::EachRegion)
42+
region_iter = region_iterator(adapter.parent)
43+
compute!(region_iter)
44+
return adapter
45+
end

test/test_iterators.jl

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Test: @test, @testset, @test_throws
22
import ITensorNetworksNext as ITensorNetworks
3-
using .ITensorNetworks: RegionIterator, SweepIterator, compute!, increment!, islaststep, state
3+
using .ITensorNetworks: RegionIterator, SweepIterator, IncrementOnly, compute!, increment!, islaststep, state, eachregion
44

55
module TestIteratorUtils
66

@@ -157,5 +157,65 @@ end
157157
@test SA_c == [1, 4, 9, 16, 25]
158158

159159
end
160+
161+
@testset "IncrementOnly" begin
162+
TI = TestIteratorUtils.TestIterator(1, 5, [])
163+
NI = IncrementOnly(TI)
164+
165+
NI_c = []
166+
167+
for _ in IncrementOnly(TI)
168+
push!(NI_c, state(TI))
169+
end
170+
171+
@test length(NI_c) == 5
172+
@test isempty(TI.output)
173+
end
174+
175+
@testset "EachRegion" begin
176+
prob = TestIteratorUtils.TestProblem([])
177+
prob_region = TestIteratorUtils.TestProblem([])
178+
179+
SI = SweepIterator(prob, 5)
180+
SI_region = SweepIterator(prob_region, 5)
181+
182+
callback = []
183+
callback_region = []
184+
185+
let i = 1
186+
for _ in SI
187+
push!(callback, i)
188+
i += 1
189+
end
190+
end
191+
192+
@test length(callback) == 5
193+
194+
let i = 1
195+
for _ in eachregion(SI_region)
196+
push!(callback_region, i)
197+
i += 1
198+
end
199+
end
200+
201+
@test length(callback_region) == 10
202+
203+
@test prob.data == prob_region.data
204+
205+
@test prob.data[1:2:end] == fill(1, 5)
206+
@test prob.data[2:2:end] == fill(2, 5)
207+
208+
209+
let i = 1, prob = TestIteratorUtils.TestProblem([])
210+
SI = SweepIterator(prob, 1)
211+
cb = []
212+
for _ in eachregion(SI)
213+
push!(cb, i)
214+
i += 1
215+
end
216+
@test length(cb) == 2
217+
end
218+
219+
end
160220
end
161221
end

0 commit comments

Comments
 (0)