Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 98 additions & 13 deletions src/OutputWriters/windowed_time_average.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Oceananigans.Diagnostics: AbstractDiagnostic
using Oceananigans.OutputWriters: fetch_output
using Oceananigans.Utils: AbstractSchedule, prettytime
using Oceananigans.TimeSteppers: Clock
using Dates: Period

import Oceananigans: run_diagnostic!
import Oceananigans.Utils: TimeInterval, SpecifiedTimes
Expand Down Expand Up @@ -91,30 +92,67 @@ function (sch::AveragedTimeInterval)(model)
return scheduled
end
initialize_schedule!(sch::AveragedTimeInterval, clock) = nothing
outside_window(sch::AveragedTimeInterval, clock) = clock.time <= next_actuation_time(sch) - sch.window
outside_window(sch::AveragedTimeInterval, clock) = clock.time <= next_actuation_time(sch) - sch.window
end_of_window(sch::AveragedTimeInterval, clock) = clock.time >= next_actuation_time(sch)

TimeInterval(sch::AveragedTimeInterval) = TimeInterval(sch.interval)
Base.copy(sch::AveragedTimeInterval) = AveragedTimeInterval(sch.interval, window=sch.window, stride=sch.stride)



"""
mutable struct AveragedSpecifiedTimes <: AbstractSchedule

A schedule for averaging over windows that precede SpecifiedTimes.
"""
mutable struct AveragedSpecifiedTimes <: AbstractSchedule
mutable struct AveragedSpecifiedTimes{W} <: AbstractSchedule
specified_times :: SpecifiedTimes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since SpecifiedTimes is not a concrete type, this should also be a type parameter

window :: Float64
window :: W
stride :: Int
collecting :: Bool
end

const VaryingWindowAveragedSpecifiedTimes = AveragedSpecifiedTimes{<:Vector}

AveragedSpecifiedTimes(specified_times::SpecifiedTimes; window, stride=1) =
AveragedSpecifiedTimes(specified_times, window, stride, false)

AveragedSpecifiedTimes(times; kw...) = AveragedSpecifiedTimes(SpecifiedTimes(times); kw...)
AveragedSpecifiedTimes(times; window, kw...) = AveragedSpecifiedTimes(times, window; kw...)

function determine_epsilon(eltype)
if eltype <: AbstractFloat
return eps(eltype)
elseif eltype <: Period
return Second(0)
else
return 0
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function determine_epsilon(eltype)
if eltype <: AbstractFloat
return eps(eltype)
elseif eltype <: Period
return Second(0)
else
return 0
end
end
determine_epsilon(eltype) = 0
determine_epsilon(eltype::AbstractType) = eps(eltype)
determine_epsilon(::Period) = Second(0)

the name of the function is slightly problematic

Copy link
Collaborator Author

@xkykai xkykai Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you meant

determine_epsilon(eltype) = 0
determine_epsilon(::Type{T}) where T <: AbstractFloat = eps(T)
determine_epsilon(::Period) = Second(0)

right?

I agree the name is terrible, but I haven't figured out a better alternative...


function AveragedSpecifiedTimes(times, window::Vector; kw...)
length(window) == length(times) || throw(ArgumentError("When providing a vector of windows, its length $(length(window)) must match the number of specified times $(length(times))."))
perm = sortperm(times)
sorted_times = times[perm]
sorted_window = window[perm]
time_diff = diff(vcat(0, sorted_times))

epsilon = determine_epsilon(eltype(window))
any(time_diff .- sorted_window .< -epsilon) && throw(ArgumentError("Averaging windows overlap. Ensure that for each specified time tᵢ, tᵢ - windowᵢ ≥ tᵢ₋₁."))

return AveragedSpecifiedTimes(SpecifiedTimes(sorted_times); window=sorted_window, kw...)
end

function AveragedSpecifiedTimes(times, window::Union{<:Number, <:Period}; kw...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function AveragedSpecifiedTimes(times, window::Union{<:Number, <:Period}; kw...)
function AveragedSpecifiedTimes(times, window; kw...)

can this be the fallback?

sorted_times = sort(times)
time_diff = diff(vcat(0, sorted_times))

epsilon = determine_epsilon(typeof(window))

any(time_diff .- window .< -epsilon) && throw(ArgumentError("Averaging window $window is too large and causes overlapping windows. Ensure that for each specified time tᵢ, tᵢ - window ≥ tᵢ₋₁."))

return AveragedSpecifiedTimes(SpecifiedTimes(times); window, kw...)
end

get_next_window(schedule::VaryingWindowAveragedSpecifiedTimes) = schedule.window[schedule.specified_times.previous_actuation + 1]
get_next_window(schedule::AveragedSpecifiedTimes) = schedule.window

function (schedule::AveragedSpecifiedTimes)(model)
time = model.clock.time
Expand All @@ -123,7 +161,7 @@ function (schedule::AveragedSpecifiedTimes)(model)
next > length(schedule.specified_times.times) && return false

next_time = schedule.specified_times.times[next]
window = schedule.window
window = get_next_window(schedule)

schedule.collecting || time >= next_time - window
end
Expand All @@ -134,7 +172,8 @@ function outside_window(schedule::AveragedSpecifiedTimes, clock)
next = schedule.specified_times.previous_actuation + 1
next > length(schedule.specified_times.times) && return true
next_time = schedule.specified_times.times[next]
return clock.time < next_time - schedule.window
window = get_next_window(schedule)
return clock.time < next_time - window
end

function end_of_window(schedule::AveragedSpecifiedTimes, clock)
Expand All @@ -144,6 +183,11 @@ function end_of_window(schedule::AveragedSpecifiedTimes, clock)
return clock.time >= next_time
end

TimeInterval(sch::AveragedSpecifiedTimes) = TimeInterval(sch.specified_times.times)
Base.copy(sch::AveragedSpecifiedTimes) = AveragedSpecifiedTimes(copy(sch.specified_times); window=sch.window, stride=sch.stride)

next_actuation_time(sch::AveragedSpecifiedTimes) = Oceananigans.Utils.next_actuation_time(sch.specified_times)

#####
##### WindowedTimeAverage
#####
Expand All @@ -168,7 +212,7 @@ stride(wta::SpecifiedWindowedTimeAverage) = wta.schedule.stride
WindowedTimeAverage(operand, model=nothing; schedule)

Returns an object for computing running averages of `operand` over `schedule.window` and
recurring on `schedule.interval`, where `schedule` is an `AveragedTimeInterval`.
recurring on `schedule.interval`, where `schedule` is an `AveragedTimeInterval` or `AveragedSpecifiedTimes`.
During the collection period, averages are computed every `schedule.stride` iteration.

`operand` may be a `Oceananigans.Field` or a function that returns an array or scalar.
Expand Down Expand Up @@ -261,6 +305,40 @@ function advance_time_average!(wta::WindowedTimeAverage, model)
return nothing
end

function advance_time_average!(wta::SpecifiedWindowedTimeAverage, model)

unscheduled = model.clock.iteration == 0 || outside_window(wta.schedule, model.clock)
if !(unscheduled)
if !(wta.schedule.collecting)
# Zero out result to begin new accumulation window
wta.result .= 0

# Begin collecting window-averaged increments
wta.schedule.collecting = true

wta.window_start_time = next_actuation_time(wta.schedule) - get_next_window(wta.schedule)
wta.previous_collection_time = wta.window_start_time
wta.window_start_iteration = model.clock.iteration - 1
# @info "t $(prettytime(model.clock.time)), next actuation time: $(prettytime(next_actuation_time(wta.schedule))), window $(prettytime(wta.schedule.window))"
end

if end_of_window(wta.schedule, model.clock)
accumulate_result!(wta, model)
# Save averaging start time and the initial data collection time
wta.schedule.collecting = false
wta.schedule.specified_times.previous_actuation += 1

elseif mod(model.clock.iteration - wta.window_start_iteration, stride(wta)) == 0
accumulate_result!(wta, model)
else
# Off stride, so do nothing.
end

end
return nothing
end


# So it can be used as a Diagnostic
run_diagnostic!(wta::WindowedTimeAverage, model) = advance_time_average!(wta, model)

Expand All @@ -271,8 +349,14 @@ Base.summary(schedule::AveragedTimeInterval) = string("AveragedTimeInterval(",
"stride=", schedule.stride, ", ",
"interval=", prettytime(schedule.interval), ")")

Base.summary(schedule::AveragedSpecifiedTimes) = string("AveragedSpecifiedTimes(",
"window=", prettytime(schedule.window), ", ",
"stride=", schedule.stride, ", ",
"times=", schedule.specified_times, ")")

show_averaging_schedule(schedule) = ""
show_averaging_schedule(schedule::AveragedTimeInterval) = string(" averaged on ", summary(schedule))
show_averaging_schedule(schedule::AveragedSpecifiedTimes) = string(" averaged on ", summary(schedule))

output_averaging_schedule(output::WindowedTimeAverage) = output.schedule

Expand All @@ -282,6 +366,8 @@ output_averaging_schedule(output::WindowedTimeAverage) = output.schedule

time_average_outputs(schedule, outputs, model) = schedule, outputs # fallback

const AveragedTimeSchedule = Union{AveragedTimeInterval, AveragedSpecifiedTimes}

"""
time_average_outputs(schedule::AveragedTimeInterval, outputs, model, field_slicer)

Expand All @@ -290,17 +376,16 @@ Wrap each `output` in a `WindowedTimeAverage` on the time-averaged `schedule` an
Returns the `TimeInterval` associated with `schedule` and a `NamedTuple` or `Dict` of the wrapped
outputs.
"""
function time_average_outputs(schedule::AveragedTimeInterval, outputs::Dict, model)
function time_average_outputs(schedule::AveragedTimeSchedule, outputs::Dict, model)
averaged_outputs = Dict(name => WindowedTimeAverage(output, model; schedule=copy(schedule))
for (name, output) in outputs)

return TimeInterval(schedule), averaged_outputs
end

function time_average_outputs(schedule::AveragedTimeInterval, outputs::NamedTuple, model)
function time_average_outputs(schedule::AveragedTimeSchedule, outputs::NamedTuple, model)
averaged_outputs = NamedTuple(name => WindowedTimeAverage(outputs[name], model; schedule=copy(schedule))
for name in keys(outputs))

return TimeInterval(schedule), averaged_outputs
end

end
1 change: 1 addition & 0 deletions src/Utils/prettytime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,4 @@ function prettytimeunits(t, longform=true)
end

prettytime(dt::AbstractTime) = "$dt"
prettytime(t::Array) = prettytime.(t)
2 changes: 2 additions & 0 deletions src/Utils/schedules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ function specified_times_str(st)
return string(str, "]")
end

Base.copy(st::SpecifiedTimes) = SpecifiedTimes(copy(st.times), st.previous_actuation)

#####
##### ConsecutiveIterations
#####
Expand Down
8 changes: 8 additions & 0 deletions src/Utils/times_and_datetimes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,19 @@ end
@inline add_time_interval(base::AbstractTime, interval::Number, count=1) = base + seconds_to_nanosecond(interval * count)
@inline add_time_interval(base::AbstractTime, interval::Period, count=1) = base + count * interval

@inline add_time_interval(base::Number, interval::Array{<:Number}, count=1) = interval[count]

function period_type(interval::Number)
FT = Oceananigans.defaults.FloatType
return FT
end

function period_type(interval::Array{<:Number})
FT = Oceananigans.defaults.FloatType
return Array{FT, 1}
end

period_type(interval::Dates.Period) = typeof(interval)
time_type(interval::Number) = typeof(interval)
time_type(interval::Dates.Period) = Dates.DateTime
time_type(interval::Array{<:Number}) = eltype(interval)
Loading