Skip to content

Commit c6122c6

Browse files
authored
Sweeping algorithms based on AlgorithmsInterface.jl (#30)
1 parent 2fe6562 commit c6122c6

16 files changed

+944
-445
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.4"
4+
version = "0.3.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
88
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
9+
AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d"
910
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
1011
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1112
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
@@ -32,6 +33,7 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations"
3233
[compat]
3334
AbstractTrees = "0.4.5"
3435
Adapt = "4.3"
36+
AlgorithmsInterface = "0.1.0"
3537
BackendSelection = "0.1.6"
3638
Combinatorics = "1"
3739
DataGraphs = "0.2.7"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ ITensorNetworksNext = {path = ".."}
88

99
[compat]
1010
Documenter = "1"
11-
ITensorNetworksNext = "0.2"
11+
ITensorNetworksNext = "0.3"
1212
Literate = "2"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
55
ITensorNetworksNext = {path = ".."}
66

77
[compat]
8-
ITensorNetworksNext = "0.2"
8+
ITensorNetworksNext = "0.3"
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
module AlgorithmsInterfaceExtensions
2+
3+
import AlgorithmsInterface as AI
4+
5+
#========================== Patches for AlgorithmsInterface.jl ============================#
6+
7+
abstract type Problem <: AI.Problem end
8+
abstract type Algorithm <: AI.Algorithm end
9+
abstract type State <: AI.State end
10+
11+
function AI.initialize_state!(
12+
problem::Problem, algorithm::Algorithm, state::State; iteration = 0, kwargs...
13+
)
14+
for (k, v) in pairs(kwargs)
15+
setproperty!(state, k, v)
16+
end
17+
state.iteration = iteration
18+
AI.initialize_state!(
19+
problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state
20+
)
21+
return state
22+
end
23+
24+
function AI.initialize_state(
25+
problem::Problem, algorithm::Algorithm; kwargs...
26+
)
27+
stopping_criterion_state = AI.initialize_state(
28+
problem, algorithm, algorithm.stopping_criterion
29+
)
30+
return DefaultState(; stopping_criterion_state, kwargs...)
31+
end
32+
33+
#============================ DefaultState ================================================#
34+
35+
@kwdef mutable struct DefaultState{
36+
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
37+
} <: State
38+
iterate::Iterate
39+
iteration::Int = 0
40+
stopping_criterion_state::StoppingCriterionState
41+
end
42+
43+
#============================ increment! ==================================================#
44+
45+
# Custom version of `increment!` that also takes the problem and algorithm as arguments.
46+
function AI.increment!(problem::Problem, algorithm::Algorithm, state::State)
47+
return AI.increment!(state)
48+
end
49+
50+
#============================ solve! ======================================================#
51+
52+
# Custom version of `solve!` that allows specifying the logger and also overloads
53+
# `increment!` on the problem and algorithm.
54+
function basetypenameof(x)
55+
return Symbol(last(split(String(Symbol(Base.typename(typeof(x)).wrapper)), ".")))
56+
end
57+
default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_)
58+
function default_logging_context_prefix(problem::Problem, algorithm::Algorithm)
59+
return Symbol(
60+
default_logging_context_prefix(problem),
61+
default_logging_context_prefix(algorithm),
62+
)
63+
end
64+
function AI.solve!(
65+
problem::Problem, algorithm::Algorithm, state::State;
66+
logging_context_prefix = default_logging_context_prefix(problem, algorithm),
67+
kwargs...,
68+
)
69+
logger = AI.algorithm_logger()
70+
71+
context_suffixes = [:Start, :PreStep, :PostStep, :Stop]
72+
contexts = Dict(context_suffixes .=> Symbol.(logging_context_prefix, context_suffixes))
73+
74+
# initialize the state and emit message
75+
AI.initialize_state!(problem, algorithm, state; kwargs...)
76+
AI.emit_message(logger, problem, algorithm, state, contexts[:Start])
77+
78+
# main body of the algorithm
79+
while !AI.is_finished!(problem, algorithm, state)
80+
AI.increment!(problem, algorithm, state)
81+
82+
# logging event between convergence check and algorithm step
83+
AI.emit_message(logger, problem, algorithm, state, contexts[:PreStep])
84+
85+
# algorithm step
86+
AI.step!(problem, algorithm, state; logging_context_prefix)
87+
88+
# logging event between algorithm step and convergence check
89+
AI.emit_message(logger, problem, algorithm, state, contexts[:PostStep])
90+
end
91+
92+
# emit message about finished state
93+
AI.emit_message(logger, problem, algorithm, state, contexts[:Stop])
94+
return state
95+
end
96+
97+
function AI.solve(
98+
problem::Problem, algorithm::Algorithm;
99+
logging_context_prefix = default_logging_context_prefix(problem, algorithm),
100+
kwargs...,
101+
)
102+
state = AI.initialize_state(problem, algorithm; kwargs...)
103+
return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...)
104+
end
105+
106+
#============================ AlgorithmIterator ===========================================#
107+
108+
abstract type AlgorithmIterator end
109+
110+
function algorithm_iterator(
111+
problem::Problem, algorithm::Algorithm, state::State
112+
)
113+
return DefaultAlgorithmIterator(problem, algorithm, state)
114+
end
115+
116+
function AI.is_finished!(iterator::AlgorithmIterator)
117+
return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state)
118+
end
119+
function AI.is_finished(iterator::AlgorithmIterator)
120+
return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state)
121+
end
122+
function AI.increment!(iterator::AlgorithmIterator)
123+
return AI.increment!(iterator.problem, iterator.algorithm, iterator.state)
124+
end
125+
function AI.step!(iterator::AlgorithmIterator)
126+
return AI.step!(iterator.problem, iterator.algorithm, iterator.state)
127+
end
128+
function Base.iterate(iterator::AlgorithmIterator, init = nothing)
129+
AI.is_finished!(iterator) && return nothing
130+
AI.increment!(iterator)
131+
AI.step!(iterator)
132+
return iterator.state, nothing
133+
end
134+
135+
struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator
136+
problem::Problem
137+
algorithm::Algorithm
138+
state::State
139+
end
140+
141+
#============================ with_algorithmlogger ========================================#
142+
143+
# Allow passing functions, not just CallbackActions.
144+
@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...)
145+
return AI.with_algorithmlogger(f, args...)
146+
end
147+
@inline function with_algorithmlogger(f, args::Pair{Symbol}...)
148+
return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...)
149+
end
150+
151+
#============================ NestedAlgorithm =============================================#
152+
153+
abstract type NestedAlgorithm <: Algorithm end
154+
155+
function nested_algorithm(f::Function, nalgorithms::Int; kwargs...)
156+
return DefaultNestedAlgorithm(f, nalgorithms; kwargs...)
157+
end
158+
159+
max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms)
160+
161+
function get_subproblem(
162+
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State
163+
)
164+
subproblem = problem
165+
subalgorithm = algorithm.algorithms[state.iteration]
166+
substate = AI.initialize_state(subproblem, subalgorithm; state.iterate)
167+
return subproblem, subalgorithm, substate
168+
end
169+
170+
function set_substate!(
171+
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State
172+
)
173+
state.iterate = substate.iterate
174+
return state
175+
end
176+
177+
function AI.step!(
178+
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State;
179+
logging_context_prefix = Symbol()
180+
)
181+
# Get the subproblem, subalgorithm, and substate.
182+
subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state)
183+
184+
# Solve the subproblem with the subalgorithm.
185+
logging_context_prefix = Symbol(
186+
logging_context_prefix, default_logging_context_prefix(subalgorithm)
187+
)
188+
AI.solve!(subproblem, subalgorithm, substate; logging_context_prefix)
189+
190+
# Update the state with the substate.
191+
set_substate!(problem, algorithm, state, substate)
192+
193+
return state
194+
end
195+
196+
#=
197+
DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm})
198+
199+
An algorithm that consists of running an algorithm at each iteration
200+
from a list of stored algorithms.
201+
=#
202+
@kwdef struct DefaultNestedAlgorithm{
203+
ChildAlgorithm <: Algorithm,
204+
Algorithms <: AbstractVector{ChildAlgorithm},
205+
StoppingCriterion <: AI.StoppingCriterion,
206+
} <: NestedAlgorithm
207+
algorithms::Algorithms
208+
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms))
209+
end
210+
function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
211+
return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
212+
end
213+
214+
#============================ FlattenedAlgorithm ==========================================#
215+
216+
# Flatten a nested algorithm.
217+
abstract type FlattenedAlgorithm <: Algorithm end
218+
abstract type FlattenedAlgorithmState <: State end
219+
220+
function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...)
221+
return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...)
222+
end
223+
224+
function AI.initialize_state(
225+
problem::Problem, algorithm::FlattenedAlgorithm; kwargs...
226+
)
227+
stopping_criterion_state = AI.initialize_state(
228+
problem, algorithm, algorithm.stopping_criterion
229+
)
230+
return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...)
231+
end
232+
function AI.increment!(
233+
problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState
234+
)
235+
# Increment the total iteration count.
236+
state.iteration += 1
237+
# TODO: Use `is_finished!` instead?
238+
if state.child_iteration max_iterations(algorithm.algorithms[state.parent_iteration])
239+
# We're on the last iteration of the child algorithm, so move to the next
240+
# child algorithm.
241+
state.parent_iteration += 1
242+
state.child_iteration = 1
243+
else
244+
# Iterate the child algorithm.
245+
state.child_iteration += 1
246+
end
247+
return state
248+
end
249+
function AI.step!(
250+
problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState;
251+
logging_context_prefix = Symbol()
252+
)
253+
algorithm_sweep = algorithm.algorithms[state.parent_iteration]
254+
state_sweep = AI.initialize_state(
255+
problem, algorithm_sweep;
256+
state.iterate, iteration = state.child_iteration
257+
)
258+
AI.step!(problem, algorithm_sweep, state_sweep; logging_context_prefix)
259+
state.iterate = state_sweep.iterate
260+
return state
261+
end
262+
263+
@kwdef struct DefaultFlattenedAlgorithm{
264+
ChildAlgorithm <: Algorithm,
265+
Algorithms <: AbstractVector{ChildAlgorithm},
266+
StoppingCriterion <: AI.StoppingCriterion,
267+
} <: FlattenedAlgorithm
268+
algorithms::Algorithms
269+
stopping_criterion::StoppingCriterion =
270+
AI.StopAfterIteration(sum(max_iterations, algorithms))
271+
end
272+
function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
273+
return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
274+
end
275+
276+
@kwdef mutable struct DefaultFlattenedAlgorithmState{
277+
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
278+
} <: FlattenedAlgorithmState
279+
iterate::Iterate
280+
iteration::Int = 0
281+
parent_iteration::Int = 1
282+
child_iteration::Int = 0
283+
stopping_criterion_state::StoppingCriterionState
284+
end
285+
286+
#============================ NonIterativeAlgorithm =======================================#
287+
288+
# Algorithm that only performs a single step.
289+
abstract type NonIterativeAlgorithm <: Algorithm end
290+
abstract type NonIterativeAlgorithmState <: State end
291+
292+
function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...)
293+
return DefaultNonIterativeAlgorithmState(; kwargs...)
294+
end
295+
function AI.solve!(
296+
problem::Problem, algorithm::NonIterativeAlgorithm, state::State; kwargs...
297+
)
298+
return throw(MethodError(AI.solve!, (problem, algorithm, state)))
299+
end
300+
301+
@kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <:
302+
NonIterativeAlgorithmState
303+
iterate::Iterate
304+
end
305+
306+
end

src/ITensorNetworksNext.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
module ITensorNetworksNext
22

3+
include("AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl")
34
include("LazyNamedDimsArrays/LazyNamedDimsArrays.jl")
45
include("abstracttensornetwork.jl")
56
include("tensornetwork.jl")
67
include("TensorNetworkGenerators/TensorNetworkGenerators.jl")
78
include("contract_network.jl")
8-
include("abstract_problem.jl")
9-
include("iterators.jl")
10-
include("adapters.jl")
9+
include("sweeping/utils.jl")
10+
include("sweeping/eigenproblem.jl")
1111

1212
end

src/abstract_problem.jl

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)