Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.2.6"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand Down
6 changes: 3 additions & 3 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.8.2"
julia_version = "1.8.3"
manifest_format = "2.0"
project_hash = "ce26edf36ffc3f7c29725c9c72c38a467df9bb2b"

Expand Down Expand Up @@ -149,10 +149,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.3.2"

[[deps.ClimaTimeSteppers]]
deps = ["CUDA", "ClimaComms", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "SciMLBase", "StaticArrays"]
deps = ["CUDA", "ClimaComms", "DataStructures", "Dates", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "SciMLBase", "StaticArrays"]
path = ".."
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
version = "0.2.5"
version = "0.2.6"

[[deps.CloseOpenIntervals]]
deps = ["ArrayInterface", "Static"]
Expand Down
9 changes: 9 additions & 0 deletions docs/src/callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,12 @@ EveryXWallTimeSeconds
EveryXSimulationTime
EveryXSimulationSteps
```

# Progress Callbacks
```@meta
CurrentModule = ClimaTimeSteppers
```
```@docs
BasicProgressCallback
TerminalProgressCallback
```
6 changes: 3 additions & 3 deletions perf/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.8.2"
julia_version = "1.8.3"
manifest_format = "2.0"
project_hash = "f4b3eeecc9753e545581990f89e8fdf9df989ae0"

Expand Down Expand Up @@ -144,10 +144,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.3.2"

[[deps.ClimaTimeSteppers]]
deps = ["CUDA", "ClimaComms", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "SciMLBase", "StaticArrays"]
deps = ["CUDA", "ClimaComms", "DataStructures", "Dates", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "SciMLBase", "StaticArrays"]
path = ".."
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
version = "0.2.5"
version = "0.2.6"

[[deps.CloseOpenIntervals]]
deps = ["ArrayInterface", "Static"]
Expand Down
3 changes: 2 additions & 1 deletion src/ClimaTimeSteppers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ realview(x::Union{Array, SArray, MArray}) = x
realview(x::CuArray) = x


import DiffEqBase, SciMLBase, LinearAlgebra, DiffEqCallbacks, Krylov
import SciMLBase, DiffEqBase, DiffEqCallbacks, Krylov, LinearAlgebra, Dates

include("sparse_containers.jl")
include("functions.jl")
Expand All @@ -69,6 +69,7 @@ end

SciMLBase.allowscomplex(alg::DistributedODEAlgorithm) = true
include("integrators.jl")
include("progress_callbacks.jl")

include("solvers/update_signal_handler.jl")
include("solvers/convergence_condition.jl")
Expand Down
36 changes: 27 additions & 9 deletions src/integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ function DiffEqBase.__init(
save_func = (u, t, integrator) -> copy(u),
callback = nothing,
advance_to_tstop = false,
dtchangeable = true, # custom kwarg
stepstop = -1, # custom kwarg
progress = false,
dtchangeable = true, # custom kwarg
stepstop = -1, # custom kwarg
progress_kwargs = (;), # custom kwarg
kwargs...,
)
(; u0, p) = prob
Expand All @@ -98,7 +100,23 @@ function DiffEqBase.__init(
DiffEqCallbacks.SavedValues(sol.t, sol.u),
save_everystep,
)
callback = DiffEqBase.CallbackSet(callback, saving_callback)
progress_callback = if progress == false
nothing
elseif progress == true
if stdout isa Base.TTY
TerminalProgressCallback(typeof(t0); progress_kwargs...)
else
BasicProgressCallback(; progress_kwargs...)
end
elseif progress == :terminal
TerminalProgressCallback(typeof(t0); progress_kwargs...)
elseif progress == :basic
BasicProgressCallback(; progress_kwargs...)
else
error("progress must be true, false, :terminal, or :basic")
end
callback =
DiffEqBase.CallbackSet(callback, saving_callback, progress_callback)
isempty(callback.continuous_callbacks) ||
error("Continuous callbacks are not supported")

Expand Down Expand Up @@ -226,14 +244,14 @@ function __step!(integrator)
tdir(integrator) * _dt
step_u!(integrator)

# increment t by dt, rounding to the first tstop if that is roughly
# equivalent up to machine precision; the specific bound of 100 * eps...
# is taken from OrdinaryDiffEq.jl
# increment t by dt, rounding to the first tstop if that is equivalent up to
# round-off error; they are considered equivalent if they differ by less
# than 100eps(t), which is roughly identical to what OrdinaryDiffEq.jl does:
# https://github.com/SciML/OrdinaryDiffEq.jl/blob/129c76bcc35fd9801f36ce090035c1b750a842ec/src/integrators/integrator_utils.jl#L229-L236
t_plus_dt = integrator.t + integrator.dt
t_unit = oneunit(integrator.t)
max_t_error = 100 * eps(float(integrator.t / t_unit)) * t_unit
integrator.t =
!isempty(tstops) && abs(first(tstops) - t_plus_dt) < max_t_error ?
!isempty(tstops) &&
abs(first(tstops) - t_plus_dt) < 100 * eps(integrator.t) ?
first(tstops) : t_plus_dt

# apply callbacks
Expand Down
220 changes: 220 additions & 0 deletions src/progress_callbacks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
export BasicProgressCallback, TerminalProgressCallback

# '█' = \blockfull
# '▐' = \blockrighthalf
# '▌' = \blocklefthalf

"""
BasicProgressCallback(; kwargs...)

Creates a `DiffEqBase.DiscreteCallback` that prints out a progress bar as the
integrator runs.

# Keywords
- `io::IO = stdout`: output stream to which the progress bar is printed
- `min_update_time::Real = 0.5`: minimum delay (in seconds) between prints
- `bar_title::String = "Progress: "`: string that gets printed to the left of
the progress bar
- `bar_length::Integer = 100`: number of characters in the progress bar
- `show_tens_intervals::Bool = true`: whether to add ticks above the progress
bar that indicate the positions of 0%, 10%, 20%, ..., 100%; if this is
`false`, only the ticks at 0% and 100% are shown
"""
function BasicProgressCallback(;
io = stdout,
min_update_time = 0.5,
bar_title = "Progress: ",
bar_length = 100,
show_tens_intervals = true,
)
if show_tens_intervals
if bar_length % 10 != 0
new_bar_length = bar_length - bar_length % 10
@warn "bar_length must be a multiple of 10 if show_tens_intervals \
is true; decreasing the given value of $bar_length to \
$new_bar_length"
bar_length = new_bar_length
end
interval_length = bar_length ÷ 10
interval_labels =
prod(i -> '0' * lpad(i, interval_length - 1), 1:10) * "0 (%)"
interval_ticks = ('▌' * ' '^(interval_length - 2) * '▐')^10
bar_header =
' '^(length(bar_title)) * interval_labels * '\n' *
' '^(length(bar_title)) * interval_ticks * '\n' * bar_title
else
edge_ticks = '▌' * ' '^(bar_length - 2) * '▐'
bar_header = ' '^(length(bar_title)) * edge_ticks * '\n' * bar_title
end
time = Ref{Float64}()
prev_time = Ref{Float64}()
prev_filled_bar_length = Ref{Int}()
function initialize(cb, u, t, integrator)
print(io, bar_header)
prev_time[] = time_ns() / 1e9
prev_filled_bar_length[] = 0
end
function condition(u, t, integrator)
time[] = time_ns() / 1e9
return time[] >= prev_time[] + min_update_time ||
isempty(integrator.tstops)
end
function affect!(integrator)
t0 = integrator.sol.prob.tspan[1]
tf = maximum(integrator.tstops.valtree)
bar_fill_ratio = (integrator.t - t0) / (tf - t0)
filled_bar_length = floor(Int, bar_length * bar_fill_ratio)
added_bar_length = filled_bar_length - prev_filled_bar_length[]
added_bar_length > 0 && print(io, '█'^added_bar_length)
prev_time[] = time[]
prev_filled_bar_length[] = filled_bar_length
end
function finalize(cb, u, t, integrator)
added_bar_length = bar_length - prev_filled_bar_length[]
println(io, '█'^added_bar_length)
end
return DiffEqBase.DiscreteCallback(condition, affect!; initialize, finalize)
end

"""
TerminalProgressCallback([tType]; kwargs...)

Creates a `DiffEqBase.DiscreteCallback` that prints out a progress bar as the
integrator runs, along with an estimate of the time remaining until the
integrator is finished, and also an optional user-specified message (e.g.,
diagnostic information about the integrator's state, or a Unicode plot).

This progress bar is designed to be overwritten on every update, so it should
only be printed to a UNIX terminal (i.e., a terminal that supports clearing the
previous line by printing the control sequence `"\\e[1A\\r\\e[0K"`).

The time remaining is estimated by computing an exponential moving average of
the integrator's speed (seconds of real time that elapse per unit of integrator
time) and multiplying this average speed by the remaining integrator time; the
first integrator step is not included in the average so as to avoid biasing the
estimate with compilation time.

# Arguments
- `tType::Type = Float64`: type of `integrator.t`

# Keywords
- `io::IO = stdout`: terminal output stream to which the progress bar is printed
- `min_update_time::Real = 0.5`: minimum delay (in seconds) between prints
- `bar_title::String = "Progress: "`: string that gets printed to the left of
the progress bar
- `relative_bar_length::Real = 0.7`: ratio between the number of characters in
the progress bar and the width of the terminal (`displaysize(io)[2]`)
- `eta_title::String = "Time Remaining: "`: string that gets printed to the left
of the time remaining
- `new_speed_weight::Real = 0.1`: weight of the new speed in the formula for
updating the exponential moving average —
`average := new_speed_weight * new_speed + (1 - new_speed_weight) * average`
- `custom_message::Union{Nothing, Function} = nothing`: an optional function of
the form `(integrator, terminal_width) -> String` that generates a custom
message whenever the progress bar is updated; the generated string can have
multiple lines, and it is recommended to limit the length of each line to be
no bigger than the terminal width in order to improve readability
- `clear_when_finished::Bool = true`: whether to clear away the progress bar
(and any other information that was printed) when the integrator is finished
"""
function TerminalProgressCallback(
::Type{tType} = Float64;
io = stdout,
min_update_time = 0.5,
bar_title = "Progress: ",
relative_bar_length = 0.7,
eta_title = "Time Remaining: ",
new_speed_weight = 0.1,
custom_message = nothing,
clear_when_finished = true,
) where {tType}
clear_line = "\e[1A\r\e[0K"
time = Ref{Float64}()
prev_time = Ref{Float64}()
prev_progress_string = Ref{String}()
is_first_step = Ref{Bool}()
is_first_speed = Ref{Bool}()
prev_t = Ref{tType}()
average_speed = Ref{typeof(one(Float64) / one(tType))}()
function clear_prev_progress_string(terminal_width)
prev_progress_string_height = sum(
line -> max(1, cld(length(line), terminal_width)),
eachsplit(prev_progress_string[], '\n'),
)
return clear_line^prev_progress_string_height
end
function initialize(cb, u, t, integrator)
terminal_width = displaysize(io)[2]
bar_length = floor(Int, terminal_width * relative_bar_length)
progress_string =
bar_title * '▐' * ' '^bar_length * "▌ 0.0%\n" * eta_title
if !isnothing(custom_message)
progress_string *= '\n' * custom_message(integrator, terminal_width)
end
println(io, progress_string)
prev_time[] = time_ns() / 1e9
prev_progress_string[] = progress_string
is_first_step[] = true
end
function condition(u, t, integrator)
time[] = time_ns() / 1e9
return time[] >= prev_time[] + min_update_time ||
isempty(integrator.tstops)
end
function affect!(integrator)
terminal_width = displaysize(io)[2]
clear_string = clear_prev_progress_string(terminal_width)
bar_length = floor(Int, terminal_width * relative_bar_length)
t0 = integrator.sol.prob.tspan[1]
tf = maximum(integrator.tstops.valtree)
bar_fill_ratio = (integrator.t - t0) / (tf - t0)
filled_bar_length = floor(Int, bar_length * bar_fill_ratio)
percent_string = ' ' * string(floor(1000 * bar_fill_ratio) / 10) * '%'
if is_first_step[]
eta_string = "..."
is_first_step[] = false
is_first_speed[] = true
else
new_speed = (time[] - prev_time[]) / (integrator.t - prev_t[])
if is_first_speed[]
average_speed[] = new_speed
is_first_speed[] = false
else
average_speed[] =
new_speed_weight * new_speed +
(1 - new_speed_weight) * average_speed[]
end
eta = round(Int, average_speed[] * (tf - integrator.t))
eta_string = eta == 0 ? "0 seconds" :
string(Dates.canonicalize(Dates.Second(eta)))
end
progress_string =
bar_title * '▐' * '█'^filled_bar_length *
' '^(bar_length - filled_bar_length) * '▌' * percent_string * '\n' *
eta_title * eta_string
if !isnothing(custom_message)
progress_string *= '\n' * custom_message(integrator, terminal_width)
end
println(io, clear_string * progress_string)
prev_time[] = time[]
prev_progress_string[] = progress_string
prev_t[] = integrator.t
end
function finalize(cb, u, t, integrator)
terminal_width = displaysize(io)[2]
clear_string = clear_prev_progress_string(terminal_width)
if clear_when_finished
print(io, clear_string)
else
bar_length = floor(Int, terminal_width * relative_bar_length)
progress_string *=
bar_title * '▐' * '█'^bar_length * "▌ 100.0%\n" * eta_title
if !isnothing(custom_message)
progress_string *=
'\n' * custom_message(integrator, terminal_width)
end
println(io, clear_string * progress_string)
end
end
return DiffEqBase.DiscreteCallback(condition, affect!; initialize, finalize)
end
Loading