Skip to content

Commit c0659ae

Browse files
committed
Define NonIterativeAlgorithm, make Region a subtype of that, starting testing
1 parent 5a85467 commit c0659ae

File tree

4 files changed

+64
-57
lines changed

4 files changed

+64
-57
lines changed

src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@ abstract type Algorithm <: AI.Algorithm end
99
abstract type State <: AI.State end
1010

1111
function AI.initialize_state!(
12-
problem::Problem, algorithm::Algorithm, state::State; iterate = nothing
12+
problem::Problem, algorithm::Algorithm, state::State; kwargs...
1313
)
14-
!isnothing(iterate) && (state.iterate = iterate)
1514
AI.initialize_state!(
1615
problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state
1716
)
@@ -278,4 +277,23 @@ end
278277
stopping_criterion_state::StoppingCriterionState
279278
end
280279

280+
#============================ NonIterativeAlgorithm =======================================#
281+
282+
# Algorithm that only performs a single step.
283+
abstract type NonIterativeAlgorithm <: Algorithm end
284+
285+
function Base.getproperty(algorithm::NonIterativeAlgorithm, name::Symbol)
286+
if name :stopping_criterion
287+
return AI.StopAfterIteration(1)
288+
else
289+
return getfield(algorithm, name)
290+
end
291+
end
292+
293+
abstract type NonIterativeAlgorithmState <: State end
294+
295+
mutable struct DefaultNonIterativeAlgorithmState{Iterate} <: NonIterativeAlgorithmState
296+
iterate::Iterate
297+
end
298+
281299
end

src/sweeping/eigenproblem.jl

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -58,36 +58,6 @@ struct EigenProblem{Operator} <: AIE.Problem
5858
operator::Operator
5959
end
6060

61-
function AI.step!(problem::EigenProblem, algorithm::Sweep, state::AI.State; kwargs...)
62-
iterate = solve_region!!(
63-
problem, algorithm.region_algorithms[state.iteration](state.iterate), state.iterate
64-
)
65-
state.iterate = iterate
66-
return state
67-
end
68-
69-
# extract!, update!, insert! for the region.
70-
function solve_region!!(problem::EigenProblem, algorithm::RegionAlgorithm, state)
71-
operator = problem.operator
72-
region = algorithm.region
73-
region_kwargs = algorithm.kwargs
74-
75-
#=
76-
# Reduce the `operator` and state `x` onto the region `region`,
77-
# and call `eigsolve` on the reduced operator and state using the
78-
# keyword arguments determined from `region_kwargs`.
79-
operator_region = reduced_operator(operator, x, region)
80-
x_region = reduced_state(x, region)
81-
x_region′ = eigsolve(operator_region, x_region; region_kwargs.update...)
82-
x′ = insert(x, region, x_region′; region_kwargs.insert...)
83-
state.state = x′
84-
=#
85-
86-
# Dummy update for demonstration purposes.
87-
state′ = "region = $region" *
88-
", update_kwargs = $(region_kwargs.update)" *
89-
", insert_kwargs = $(region_kwargs.insert)"
90-
state = [state; [state′]]
91-
92-
return state
61+
function AI.step!(problem::EigenProblem, algorithm::Region, state::AIE.State; kwargs...)
62+
return error("Region step for EigenProblem not implemented.")
9363
end

src/sweeping/sweeping.jl

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ import AlgorithmsInterface as AI
22
import .AlgorithmsInterfaceExtensions as AIE
33

44
@kwdef struct Sweeping{
5-
Algorithms <: AbstractVector{<:AI.Algorithm},
6-
StoppingCriterion <: AI.StoppingCriterion,
5+
Algorithms <: AbstractVector, StoppingCriterion <: AI.StoppingCriterion,
76
} <: AIE.NestedAlgorithm
87
algorithms::Algorithms
98
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms))
@@ -24,31 +23,18 @@ which is converted into a function that always returns the same keyword argument
2423
for an region.
2524
=#
2625
@kwdef struct Sweep{
27-
RegionAlgorithms <: AbstractVector, StoppingCriterion <: AI.StoppingCriterion,
28-
} <: AIE.Algorithm
29-
region_algorithms::RegionAlgorithms
30-
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(region_algorithms))
26+
Algorithms <: AbstractVector, StoppingCriterion <: AI.StoppingCriterion,
27+
} <: AIE.NestedAlgorithm
28+
algorithms::Algorithms
29+
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms))
3130
end
3231
function Sweep(f, nalgorithms::Int; kwargs...)
33-
region_algorithms = to_region_algorithm.(f.(1:nalgorithms))
34-
return Sweep(; region_algorithms, kwargs...)
32+
return Sweep(; algorithms = f.(1:nalgorithms), kwargs...)
3533
end
36-
to_region_algorithm(algorithm::Function) = algorithm
37-
to_region_algorithm(algorithm) = Returns(region_algorithm(algorithm))
38-
39-
AIE.max_iterations(algorithm::Sweep) = length(algorithm.algorithms)
4034

41-
abstract type RegionAlgorithm end
42-
region_algorithm(algorithm::RegionAlgorithm) = algorithm
43-
region_algorithm(algorithm::NamedTuple) = Region(; algorithm...)
44-
45-
struct Region{R, Kwargs <: NamedTuple} <: RegionAlgorithm
35+
struct Region{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm
4636
region::R
4737
kwargs::Kwargs
4838
end
49-
function Region(; region, kwargs...)
50-
return Region(region, (; kwargs...))
51-
end
52-
function Region(region; kwargs...)
53-
return Region(region, (; kwargs...))
54-
end
39+
Region(; region, kwargs...) = Region(region, (; kwargs...))
40+
Region(region; kwargs...) = Region(region, (; kwargs...))

test/test_sweeping.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import AlgorithmsInterface as AI
2+
using ITensorNetworksNext: Region, Sweep, Sweeping
3+
import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE
4+
using Test: @test, @testset
5+
6+
struct TestProblem <: AIE.Problem
7+
end
8+
9+
function AI.step!(problem::TestProblem, algorithm::Region, state::AIE.State; kwargs...)
10+
state.iterate = algorithm.region
11+
return state
12+
end
13+
14+
@testset "Sweeping" begin
15+
@testset "Region" begin
16+
algorithm = Region("region"; foo = 1, bar = 2)
17+
@test algorithm isa AIE.NonIterativeAlgorithm
18+
@test algorithm isa AIE.Algorithm
19+
@test algorithm isa AI.Algorithm
20+
@test algorithm.region == "region"
21+
@test algorithm.kwargs == (; foo = 1, bar = 2)
22+
@test Region(; region = "region", foo = 1, bar = 2) == algorithm
23+
24+
problem = TestProblem()
25+
iterate = ""
26+
state = AI.solve(problem, algorithm; iterate)
27+
@test state.iterate == "region"
28+
end
29+
@testset "Sweep" begin
30+
end
31+
@testset "Sweeping" begin
32+
end
33+
end

0 commit comments

Comments
 (0)