Skip to content

Commit b92a19a

Browse files
committed
Add progress callbacks
1 parent e9e237c commit b92a19a

File tree

8 files changed

+350
-29
lines changed

8 files changed

+350
-29
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.2.6"
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
88
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
99
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
10+
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1011
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1112
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
1213
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"

docs/Manifest.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
julia_version = "1.8.2"
3+
julia_version = "1.8.3"
44
manifest_format = "2.0"
55
project_hash = "ce26edf36ffc3f7c29725c9c72c38a467df9bb2b"
66

@@ -149,10 +149,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
149149
version = "0.3.2"
150150

151151
[[deps.ClimaTimeSteppers]]
152-
deps = ["CUDA", "ClimaComms", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "SciMLBase", "StaticArrays"]
152+
deps = ["CUDA", "ClimaComms", "DataStructures", "Dates", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "SciMLBase", "StaticArrays"]
153153
path = ".."
154154
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
155-
version = "0.2.5"
155+
version = "0.2.6"
156156

157157
[[deps.CloseOpenIntervals]]
158158
deps = ["ArrayInterface", "Static"]

docs/src/callbacks.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,12 @@ EveryXWallTimeSeconds
1919
EveryXSimulationTime
2020
EveryXSimulationSteps
2121
```
22+
23+
# Progress Callbacks
24+
```@meta
25+
CurrentModule = ClimaTimeSteppers
26+
```
27+
```@docs
28+
BasicProgressCallback
29+
TerminalProgressCallback
30+
```

perf/Manifest.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
julia_version = "1.8.2"
3+
julia_version = "1.8.3"
44
manifest_format = "2.0"
55
project_hash = "f4b3eeecc9753e545581990f89e8fdf9df989ae0"
66

@@ -144,10 +144,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
144144
version = "0.3.2"
145145

146146
[[deps.ClimaTimeSteppers]]
147-
deps = ["CUDA", "ClimaComms", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "SciMLBase", "StaticArrays"]
147+
deps = ["CUDA", "ClimaComms", "DataStructures", "Dates", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "SciMLBase", "StaticArrays"]
148148
path = ".."
149149
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
150-
version = "0.2.5"
150+
version = "0.2.6"
151151

152152
[[deps.CloseOpenIntervals]]
153153
deps = ["ArrayInterface", "Static"]

src/ClimaTimeSteppers.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ realview(x::Union{Array, SArray, MArray}) = x
5757
realview(x::CuArray) = x
5858

5959

60-
import DiffEqBase, SciMLBase, LinearAlgebra, DiffEqCallbacks, Krylov
60+
import SciMLBase, DiffEqBase, DiffEqCallbacks, Krylov, LinearAlgebra, Dates
6161

6262
include("sparse_containers.jl")
6363
include("functions.jl")
@@ -69,6 +69,7 @@ end
6969

7070
SciMLBase.allowscomplex(alg::DistributedODEAlgorithm) = true
7171
include("integrators.jl")
72+
include("progress_callbacks.jl")
7273

7374
include("solvers/update_signal_handler.jl")
7475
include("solvers/convergence_condition.jl")

src/integrators.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ function DiffEqBase.__init(
7777
save_func = (u, t, integrator) -> copy(u),
7878
callback = nothing,
7979
advance_to_tstop = false,
80-
dtchangeable = true, # custom kwarg
81-
stepstop = -1, # custom kwarg
80+
progress = false,
81+
dtchangeable = true, # custom kwarg
82+
stepstop = -1, # custom kwarg
83+
progress_kwargs = (;), # custom kwarg
8284
kwargs...,
8385
)
8486
(; u0, p) = prob
@@ -98,7 +100,23 @@ function DiffEqBase.__init(
98100
DiffEqCallbacks.SavedValues(sol.t, sol.u),
99101
save_everystep,
100102
)
101-
callback = DiffEqBase.CallbackSet(callback, saving_callback)
103+
progress_callback = if progress == false
104+
nothing
105+
elseif progress == true
106+
if stdout isa Base.TTY
107+
TerminalProgressCallback(typeof(t0); progress_kwargs...)
108+
else
109+
BasicProgressCallback(; progress_kwargs...)
110+
end
111+
elseif progress == :terminal
112+
TerminalProgressCallback(typeof(t0); progress_kwargs...)
113+
elseif progress == :basic
114+
BasicProgressCallback(; progress_kwargs...)
115+
else
116+
error("progress must be true, false, :terminal, or :basic")
117+
end
118+
callback =
119+
DiffEqBase.CallbackSet(callback, saving_callback, progress_callback)
102120
isempty(callback.continuous_callbacks) ||
103121
error("Continuous callbacks are not supported")
104122

@@ -226,14 +244,14 @@ function __step!(integrator)
226244
tdir(integrator) * _dt
227245
step_u!(integrator)
228246

229-
# increment t by dt, rounding to the first tstop if that is roughly
230-
# equivalent up to machine precision; the specific bound of 100 * eps...
231-
# is taken from OrdinaryDiffEq.jl
247+
# increment t by dt, rounding to the first tstop if that is equivalent up to
248+
# round-off error; they are considered equivalent if they differ by less
249+
# than 100eps(t), which is roughly identical to what OrdinaryDiffEq.jl does:
250+
# https://github.com/SciML/OrdinaryDiffEq.jl/blob/129c76bcc35fd9801f36ce090035c1b750a842ec/src/integrators/integrator_utils.jl#L229-L236
232251
t_plus_dt = integrator.t + integrator.dt
233-
t_unit = oneunit(integrator.t)
234-
max_t_error = 100 * eps(float(integrator.t / t_unit)) * t_unit
235252
integrator.t =
236-
!isempty(tstops) && abs(first(tstops) - t_plus_dt) < max_t_error ?
253+
!isempty(tstops) &&
254+
abs(first(tstops) - t_plus_dt) < 100 * eps(integrator.t) ?
237255
first(tstops) : t_plus_dt
238256

239257
# apply callbacks

src/progress_callbacks.jl

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
export BasicProgressCallback, TerminalProgressCallback
2+
3+
# '█' = \blockfull
4+
# '▐' = \blockrighthalf
5+
# '▌' = \blocklefthalf
6+
7+
"""
8+
BasicProgressCallback(; kwargs...)
9+
10+
Creates a `DiffEqBase.DiscreteCallback` that prints out a progress bar as the
11+
integrator runs.
12+
13+
# Keywords
14+
- `io::IO = stdout`: output stream to which the progress bar is printed
15+
- `min_update_time::Real = 0.5`: minimum delay (in seconds) between prints
16+
- `bar_title::String = "Progress: "`: string that gets printed to the left of
17+
the progress bar
18+
- `bar_length::Integer = 100`: number of characters in the progress bar
19+
- `show_tens_intervals::Bool = true`: whether to add ticks above the progress
20+
bar that indicate the positions of 0%, 10%, 20%, ..., 100%; if this is
21+
`false`, only the ticks at 0% and 100% are shown
22+
"""
23+
function BasicProgressCallback(;
24+
io = stdout,
25+
min_update_time = 0.5,
26+
bar_title = "Progress: ",
27+
bar_length = 100,
28+
show_tens_intervals = true,
29+
)
30+
if show_tens_intervals
31+
if bar_length % 10 != 0
32+
new_bar_length = bar_length - bar_length % 10
33+
@warn "bar_length must be a multiple of 10 if show_tens_intervals \
34+
is true; decreasing the given value of $bar_length to \
35+
$new_bar_length"
36+
bar_length = new_bar_length
37+
end
38+
interval_length = bar_length ÷ 10
39+
interval_labels =
40+
prod(i -> '0' * lpad(i, interval_length - 1), 1:10) * "0 (%)"
41+
interval_ticks = ('' * ' '^(interval_length - 2) * '')^10
42+
bar_header =
43+
' '^(length(bar_title)) * interval_labels * '\n' *
44+
' '^(length(bar_title)) * interval_ticks * '\n' * bar_title
45+
else
46+
edge_ticks = '' * ' '^(bar_length - 2) * ''
47+
bar_header = ' '^(length(bar_title)) * edge_ticks * '\n' * bar_title
48+
end
49+
time = Ref{Float64}()
50+
prev_time = Ref{Float64}()
51+
prev_filled_bar_length = Ref{Int}()
52+
function initialize(cb, u, t, integrator)
53+
print(io, bar_header)
54+
prev_time[] = time_ns() / 1e9
55+
prev_filled_bar_length[] = 0
56+
end
57+
function condition(u, t, integrator)
58+
time[] = time_ns() / 1e9
59+
return time[] >= prev_time[] + min_update_time ||
60+
isempty(integrator.tstops)
61+
end
62+
function affect!(integrator)
63+
t0 = integrator.sol.prob.tspan[1]
64+
tf = maximum(integrator.tstops.valtree)
65+
bar_fill_ratio = (integrator.t - t0) / (tf - t0)
66+
filled_bar_length = floor(Int, bar_length * bar_fill_ratio)
67+
added_bar_length = filled_bar_length - prev_filled_bar_length[]
68+
added_bar_length > 0 && print(io, ''^added_bar_length)
69+
prev_time[] = time[]
70+
prev_filled_bar_length[] = filled_bar_length
71+
end
72+
function finalize(cb, u, t, integrator)
73+
added_bar_length = bar_length - prev_filled_bar_length[]
74+
println(io, ''^added_bar_length)
75+
end
76+
return DiffEqBase.DiscreteCallback(condition, affect!; initialize, finalize)
77+
end
78+
79+
"""
80+
TerminalProgressCallback([tType]; kwargs...)
81+
82+
Creates a `DiffEqBase.DiscreteCallback` that prints out a progress bar as the
83+
integrator runs, along with an estimate of the time remaining until the
84+
integrator is finished, and also an optional user-specified message (e.g.,
85+
diagnostic information about the integrator's state, or a Unicode plot).
86+
87+
This progress bar is designed to be overwritten on every update, so it should
88+
only be printed to a UNIX terminal (i.e., a terminal that supports clearing the
89+
previous line by printing the control sequence `"\\e[1A\\r\\e[0K"`).
90+
91+
The time remaining is estimated by computing an exponential moving average of
92+
the integrator's speed (seconds of real time that elapse per unit of integrator
93+
time) and multiplying this average speed by the remaining integrator time; the
94+
first integrator step is not included in the average so as to avoid biasing the
95+
estimate with compilation time.
96+
97+
# Arguments
98+
- `tType::Type = Float64`: type of `integrator.t`
99+
100+
# Keywords
101+
- `io::IO = stdout`: terminal output stream to which the progress bar is printed
102+
- `min_update_time::Real = 0.5`: minimum delay (in seconds) between prints
103+
- `bar_title::String = "Progress: "`: string that gets printed to the left of
104+
the progress bar
105+
- `relative_bar_length::Real = 0.7`: ratio between the number of characters in
106+
the progress bar and the width of the terminal (`displaysize(io)[2]`)
107+
- `eta_title::String = "Time Remaining: "`: string that gets printed to the left
108+
of the time remaining
109+
- `new_speed_weight::Real = 0.1`: weight of the new speed in the formula for
110+
updating the exponential moving average —
111+
`average := new_speed_weight * new_speed + (1 - new_speed_weight) * average`
112+
- `custom_message::Union{Nothing, Function} = nothing`: an optional function of
113+
the form `(integrator, terminal_width) -> String` that generates a custom
114+
message whenever the progress bar is updated; the generated string can have
115+
multiple lines, and it is recommended to limit the length of each line to be
116+
no bigger than the terminal width in order to improve readability
117+
- `clear_when_finished::Bool = true`: whether to clear away the progress bar
118+
(and any other information that was printed) when the integrator is finished
119+
"""
120+
function TerminalProgressCallback(
121+
::Type{tType} = Float64;
122+
io = stdout,
123+
min_update_time = 0.5,
124+
bar_title = "Progress: ",
125+
relative_bar_length = 0.7,
126+
eta_title = "Time Remaining: ",
127+
new_speed_weight = 0.1,
128+
custom_message = nothing,
129+
clear_when_finished = true,
130+
) where {tType}
131+
clear_line = "\e[1A\r\e[0K"
132+
time = Ref{Float64}()
133+
prev_time = Ref{Float64}()
134+
prev_progress_string = Ref{String}()
135+
is_first_step = Ref{Bool}()
136+
is_first_speed = Ref{Bool}()
137+
prev_t = Ref{tType}()
138+
average_speed = Ref{typeof(one(Float64) / one(tType))}()
139+
function clear_prev_progress_string(terminal_width)
140+
prev_progress_string_height = sum(
141+
line -> max(1, cld(length(line), terminal_width)),
142+
eachsplit(prev_progress_string[], '\n'),
143+
)
144+
return clear_line^prev_progress_string_height
145+
end
146+
function initialize(cb, u, t, integrator)
147+
terminal_width = displaysize(io)[2]
148+
bar_length = floor(Int, terminal_width * relative_bar_length)
149+
progress_string =
150+
bar_title * '' * ' '^bar_length * "▌ 0.0%\n" * eta_title
151+
if !isnothing(custom_message)
152+
progress_string *= '\n' * custom_message(integrator, terminal_width)
153+
end
154+
println(io, progress_string)
155+
prev_time[] = time_ns() / 1e9
156+
prev_progress_string[] = progress_string
157+
is_first_step[] = true
158+
end
159+
function condition(u, t, integrator)
160+
time[] = time_ns() / 1e9
161+
return time[] >= prev_time[] + min_update_time ||
162+
isempty(integrator.tstops)
163+
end
164+
function affect!(integrator)
165+
terminal_width = displaysize(io)[2]
166+
clear_string = clear_prev_progress_string(terminal_width)
167+
bar_length = floor(Int, terminal_width * relative_bar_length)
168+
t0 = integrator.sol.prob.tspan[1]
169+
tf = maximum(integrator.tstops.valtree)
170+
bar_fill_ratio = (integrator.t - t0) / (tf - t0)
171+
filled_bar_length = floor(Int, bar_length * bar_fill_ratio)
172+
percent_string = ' ' * string(floor(1000 * bar_fill_ratio) / 10) * '%'
173+
if is_first_step[]
174+
eta_string = "..."
175+
is_first_step[] = false
176+
is_first_speed[] = true
177+
else
178+
new_speed = (time[] - prev_time[]) / (integrator.t - prev_t[])
179+
if is_first_speed[]
180+
average_speed[] = new_speed
181+
is_first_speed[] = false
182+
else
183+
average_speed[] =
184+
new_speed_weight * new_speed +
185+
(1 - new_speed_weight) * average_speed[]
186+
end
187+
eta = round(Int, average_speed[] * (tf - integrator.t))
188+
eta_string = eta == 0 ? "0 seconds" :
189+
string(Dates.canonicalize(Dates.Second(eta)))
190+
end
191+
progress_string =
192+
bar_title * '' * ''^filled_bar_length *
193+
' '^(bar_length - filled_bar_length) * '' * percent_string * '\n' *
194+
eta_title * eta_string
195+
if !isnothing(custom_message)
196+
progress_string *= '\n' * custom_message(integrator, terminal_width)
197+
end
198+
println(io, clear_string * progress_string)
199+
prev_time[] = time[]
200+
prev_progress_string[] = progress_string
201+
prev_t[] = integrator.t
202+
end
203+
function finalize(cb, u, t, integrator)
204+
terminal_width = displaysize(io)[2]
205+
clear_string = clear_prev_progress_string(terminal_width)
206+
if clear_when_finished
207+
print(io, clear_string)
208+
else
209+
bar_length = floor(Int, terminal_width * relative_bar_length)
210+
progress_string *=
211+
bar_title * '' * ''^bar_length * "▌ 100.0%\n" * eta_title
212+
if !isnothing(custom_message)
213+
progress_string *=
214+
'\n' * custom_message(integrator, terminal_width)
215+
end
216+
println(io, clear_string * progress_string)
217+
end
218+
end
219+
return DiffEqBase.DiscreteCallback(condition, affect!; initialize, finalize)
220+
end

0 commit comments

Comments
 (0)