Skip to content

Commit f78c5e6

Browse files
committed
[WIP] Sweeping algorithms based on AlgorithmsInterface.jl
1 parent 5b53e00 commit f78c5e6

File tree

12 files changed

+437
-444
lines changed

12 files changed

+437
-444
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
3-
version = "0.2.3"
3+
version = "0.3.0"
44
authors = ["ITensor developers <[email protected]> and contributors"]
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
88
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
9+
AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d"
910
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
1011
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1112
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
@@ -32,6 +33,7 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations"
3233
[compat]
3334
AbstractTrees = "0.4.5"
3435
Adapt = "4.3"
36+
AlgorithmsInterface = "0.1.0"
3537
BackendSelection = "0.1.6"
3638
Combinatorics = "1"
3739
DataGraphs = "0.2.7"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
66
[compat]
77
Documenter = "1"
88
Literate = "2"
9-
ITensorNetworksNext = "0.2"
9+
ITensorNetworksNext = "0.3"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33

44
[compat]
5-
ITensorNetworksNext = "0.2"
5+
ITensorNetworksNext = "0.3"
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
module AlgorithmsInterfaceExtensions
2+
3+
import AlgorithmsInterface as AI
4+
5+
#========================== Patches for AlgorithmsInterface.jl ============================#
6+
7+
abstract type Problem <: AI.Problem end
8+
abstract type Algorithm <: AI.Algorithm end
9+
abstract type State <: AI.State end
10+
11+
function AI.initialize_state!(
12+
problem::Problem, algorithm::Algorithm, state::State; iterate = nothing
13+
)
14+
!isnothing(iterate) && (state.iterate = iterate)
15+
AI.initialize_state!(
16+
problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state
17+
)
18+
return state
19+
end
20+
21+
function AI.initialize_state(
22+
problem::Problem, algorithm::Algorithm; kwargs...
23+
)
24+
stopping_criterion_state = AI.initialize_state(
25+
problem, algorithm, algorithm.stopping_criterion
26+
)
27+
return DefaultState(; stopping_criterion_state, kwargs...)
28+
end
29+
30+
#============================ DefaultState ================================================#
31+
32+
@kwdef mutable struct DefaultState{
33+
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
34+
} <: State
35+
iterate::Iterate
36+
iteration::Int = 0
37+
stopping_criterion_state::StoppingCriterionState
38+
end
39+
40+
#============================ increment! ==================================================#
41+
42+
# Custom version of `increment!` that also takes the problem and algorithm as arguments.
43+
function AI.increment!(problem::Problem, algorithm::Algorithm, state::State)
44+
return AI.increment!(state)
45+
end
46+
47+
#============================ solve! ======================================================#
48+
49+
# Custom version of `solve!` that allows specifying the logger and also overloads
50+
# `increment!` on the problem and algorithm.
51+
function basetypenameof(x)
52+
return Symbol(last(split(String(Symbol(Base.typename(typeof(x)).wrapper)), ".")))
53+
end
54+
default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_)
55+
function default_logging_context_prefix(problem::Problem, algorithm::Algorithm)
56+
return Symbol(
57+
default_logging_context_prefix(problem),
58+
default_logging_context_prefix(algorithm),
59+
)
60+
end
61+
function AI.solve!(
62+
problem::Problem, algorithm::Algorithm, state::State;
63+
logging_context_prefix = default_logging_context_prefix(problem, algorithm),
64+
kwargs...,
65+
)
66+
logger = AI.algorithm_logger()
67+
68+
context_suffixes = [:Start, :PreStep, :PostStep, :Stop]
69+
contexts = Dict(context_suffixes .=> Symbol.(logging_context_prefix, context_suffixes))
70+
71+
# initialize the state and emit message
72+
AI.initialize_state!(problem, algorithm, state; kwargs...)
73+
AI.emit_message(logger, problem, algorithm, state, contexts[:Start])
74+
75+
# main body of the algorithm
76+
while !AI.is_finished!(problem, algorithm, state)
77+
AI.increment!(problem, algorithm, state)
78+
79+
# logging event between convergence check and algorithm step
80+
AI.emit_message(logger, problem, algorithm, state, contexts[:PreStep])
81+
82+
# algorithm step
83+
AI.step!(problem, algorithm, state; logging_context_prefix)
84+
85+
# logging event between algorithm step and convergence check
86+
AI.emit_message(logger, problem, algorithm, state, contexts[:PostStep])
87+
end
88+
89+
# emit message about finished state
90+
AI.emit_message(logger, problem, algorithm, state, contexts[:Stop])
91+
return state
92+
end
93+
94+
function AI.solve(
95+
problem::Problem, algorithm::Algorithm;
96+
logging_context_prefix = default_logging_context_prefix(problem, algorithm),
97+
kwargs...,
98+
)
99+
state = AI.initialize_state(problem, algorithm; kwargs...)
100+
return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...)
101+
end
102+
103+
#============================ AlgorithmIterator ===========================================#
104+
105+
abstract type AlgorithmIterator end
106+
107+
function algorithm_iterator(
108+
problem::Problem, algorithm::Algorithm, state::State
109+
)
110+
return DefaultAlgorithmIterator(problem, algorithm, state)
111+
end
112+
113+
function AI.is_finished!(iterator::AlgorithmIterator)
114+
return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state)
115+
end
116+
function AI.is_finished(iterator::AlgorithmIterator)
117+
return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state)
118+
end
119+
function AI.increment!(iterator::AlgorithmIterator)
120+
return AI.increment!(iterator.problem, iterator.algorithm, iterator.state)
121+
end
122+
function AI.step!(iterator::AlgorithmIterator)
123+
return AI.step!(iterator.problem, iterator.algorithm, iterator.state)
124+
end
125+
function Base.iterate(iterator::AlgorithmIterator, init = nothing)
126+
AI.is_finished!(iterator) && return nothing
127+
AI.increment!(iterator)
128+
AI.step!(iterator)
129+
return iterator.state, nothing
130+
end
131+
132+
struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator
133+
problem::Problem
134+
algorithm::Algorithm
135+
state::State
136+
end
137+
138+
#============================ with_algorithmlogger ========================================#
139+
140+
# Allow passing functions, not just CallbackActions.
141+
@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...)
142+
return AI.with_algorithmlogger(f, args...)
143+
end
144+
@inline function with_algorithmlogger(f, args::Pair{Symbol}...)
145+
return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...)
146+
end
147+
148+
#============================ NestedAlgorithm =============================================#
149+
150+
abstract type NestedAlgorithm <: Algorithm end
151+
152+
function nested_algorithm(f::Function, nalgorithms::Int; kwargs...)
153+
return DefaultNestedAlgorithm(f, nalgorithms; kwargs...)
154+
end
155+
156+
max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms)
157+
158+
function get_subproblem(
159+
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State
160+
)
161+
subproblem = problem
162+
subalgorithm = algorithm.algorithms[state.iteration]
163+
substate = AI.initialize_state(subproblem, subalgorithm; state.iterate)
164+
return subproblem, subalgorithm, substate
165+
end
166+
167+
function set_substate!(
168+
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State
169+
)
170+
state.iterate = substate.iterate
171+
return state
172+
end
173+
174+
function AI.step!(
175+
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State;
176+
logging_context_prefix = Symbol()
177+
)
178+
# Get the subproblem, subalgorithm, and substate.
179+
subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state)
180+
181+
# Solve the subproblem with the subalgorithm.
182+
logging_context_prefix = Symbol(
183+
logging_context_prefix, default_logging_context_prefix(subalgorithm)
184+
)
185+
AI.solve!(subproblem, subalgorithm, substate; logging_context_prefix)
186+
187+
# Update the state with the substate.
188+
set_substate!(problem, algorithm, state, substate)
189+
190+
return state
191+
end
192+
193+
#=
194+
DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm})
195+
196+
An algorithm that consists of running an algorithm at each iteration
197+
from a list of stored algorithms.
198+
=#
199+
@kwdef struct DefaultNestedAlgorithm{
200+
Algorithms <: AbstractVector{<:Algorithm},
201+
StoppingCriterion <: AI.StoppingCriterion,
202+
} <: NestedAlgorithm
203+
algorithms::Algorithms
204+
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms))
205+
end
206+
function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
207+
return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
208+
end
209+
210+
#============================ FlattenedAlgorithm ==========================================#
211+
212+
# Flatten a nested algorithm.
213+
abstract type FlattenedAlgorithm <: Algorithm end
214+
abstract type FlattenedAlgorithmState <: State end
215+
216+
function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...)
217+
return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...)
218+
end
219+
220+
function AI.initialize_state(
221+
problem::Problem, algorithm::FlattenedAlgorithm; kwargs...
222+
)
223+
stopping_criterion_state = AI.initialize_state(
224+
problem, algorithm, algorithm.stopping_criterion
225+
)
226+
return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...)
227+
end
228+
function AI.increment!(
229+
problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState
230+
)
231+
# Increment the total iteration count.
232+
state.iteration += 1
233+
# TODO: Use `is_finished!` instead?
234+
if state.child_iteration max_iterations(algorithm.algorithms[state.parent_iteration])
235+
# We're on the last iteration of the child algorithm, so move to the next
236+
# child algorithm.
237+
state.parent_iteration += 1
238+
state.child_iteration = 1
239+
else
240+
# Iterate the child algorithm.
241+
state.child_iteration += 1
242+
end
243+
return state
244+
end
245+
function AI.step!(
246+
problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState;
247+
logging_context_prefix = Symbol()
248+
)
249+
algorithm_sweep = algorithm.algorithms[state.parent_iteration]
250+
state_sweep = AI.initialize_state(
251+
problem, algorithm_sweep;
252+
state.iterate, iteration = state.child_iteration
253+
)
254+
AI.step!(problem, algorithm_sweep, state_sweep; logging_context_prefix)
255+
state.iterate = state_sweep.iterate
256+
return state
257+
end
258+
259+
@kwdef struct DefaultFlattenedAlgorithm{
260+
Algorithms <: AbstractVector{<:Algorithm},
261+
StoppingCriterion <: AI.StoppingCriterion,
262+
} <: FlattenedAlgorithm
263+
algorithms::Algorithms
264+
stopping_criterion::StoppingCriterion =
265+
AI.StopAfterIteration(sum(max_iterations, algorithms))
266+
end
267+
function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
268+
return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
269+
end
270+
271+
@kwdef mutable struct DefaultFlattenedAlgorithmState{
272+
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
273+
} <: FlattenedAlgorithmState
274+
iterate::Iterate
275+
iteration::Int = 0
276+
parent_iteration::Int = 1
277+
child_iteration::Int = 0
278+
stopping_criterion_state::StoppingCriterionState
279+
end
280+
281+
end

src/ITensorNetworksNext.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
module ITensorNetworksNext
22

3+
include("AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl")
34
include("LazyNamedDimsArrays/LazyNamedDimsArrays.jl")
45
include("abstracttensornetwork.jl")
56
include("tensornetwork.jl")
67
include("TensorNetworkGenerators/TensorNetworkGenerators.jl")
78
include("contract_network.jl")
8-
include("abstract_problem.jl")
9-
include("iterators.jl")
10-
include("adapters.jl")
9+
include("sweeping.jl")
10+
include("eigenproblem.jl")
1111

1212
end

src/abstract_problem.jl

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/adapters.jl

Lines changed: 0 additions & 45 deletions
This file was deleted.

0 commit comments

Comments
 (0)