Skip to content
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
5fcbcfa
first commit for array averagedspecifiedtimes
xkykai Oct 21, 2025
5457bc5
allow window to be arrays
xkykai Oct 21, 2025
96872aa
Enhance AveragedSpecifiedTimes to support vector windows and add vali…
xkykai Oct 22, 2025
61d29ac
Refactor AveragedSpecifiedTimes to ensure non-overlapping windows and…
xkykai Oct 22, 2025
c20c9b0
Fix AveragedSpecifiedTimes to use sorted times for averaging windows
xkykai Oct 22, 2025
35c41a5
Remove debug logging from AveragedSpecifiedTimes function
xkykai Oct 22, 2025
6d16825
update commented out method
xkykai Oct 22, 2025
6ea30f5
Refactor AveragedSpecifiedTimes to support varying window types and i…
xkykai Oct 22, 2025
56e6e55
Refactor AveragedTimeInterval struct to support generic types for int…
xkykai Oct 22, 2025
b3f1541
make averagedspecifiedtimes concrete
xkykai Oct 22, 2025
ed15fd6
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Oct 22, 2025
0efdbc8
change dispatch method
xkykai Oct 24, 2025
3ab8e0e
change to fallback
xkykai Oct 24, 2025
82ed44a
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Nov 3, 2025
5e769ba
refactor: update AveragedSpecifiedTimes struct to allow more flexible…
xkykai Nov 3, 2025
29ecc03
Merge branch 'xk/enable-array-averagedtimeinterval' of https://github…
xkykai Nov 4, 2025
c2132c3
md files removed
xkykai Nov 4, 2025
9dc02bd
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Nov 4, 2025
307f8b5
Revert "md files removed"
xkykai Nov 4, 2025
a71018d
add compatiblity for AveragedSpecifiedTimes with datetime, remove che…
xkykai Nov 5, 2025
d62eb35
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Nov 18, 2025
177058a
Add validation for overlapping averaging windows in AveragedSpecified…
xkykai Nov 19, 2025
ca2bebc
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Dec 5, 2025
1e59bfa
Add tests for AveragedSpecifiedTimes functionality and error handling
xkykai Dec 5, 2025
b341e65
Add test for AveragedSpecifiedTimes functionality
xkykai Dec 5, 2025
03ec093
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Dec 8, 2025
da1ea76
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Dec 8, 2025
050db9f
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Dec 8, 2025
52b77a0
Enhance WindowedTimeAverage function to handle default time based on …
xkykai Dec 11, 2025
146932e
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Dec 11, 2025
841208b
Update src/OutputWriters/windowed_time_average.jl
xkykai Dec 11, 2025
9a0c5da
Add return statements to validate_windows functions for improved clarity
xkykai Dec 11, 2025
30ad6bc
change time tol to respect find_time_index
xkykai Dec 11, 2025
cc9e6a8
Update src/Utils/prettytime.jl
xkykai Dec 11, 2025
b3e2d4a
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Dec 19, 2025
83fa6e6
move AveragedSpecifiedTimes to new file
xkykai Dec 25, 2025
d80a852
validating window at runtime and adjust overlapping windows check at …
xkykai Dec 30, 2025
152df23
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Dec 30, 2025
fa6f4ae
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Jan 2, 2026
b41405e
fix merge conflict and redundant methods
xkykai Jan 2, 2026
cf03b10
expand averagedspecifiedtimes docstring
xkykai Jan 2, 2026
c3d27b4
fix averagedtimeinterval test
xkykai Jan 2, 2026
4252ca5
remove commented code
xkykai Jan 2, 2026
f835d04
fix docstring test
xkykai Jan 2, 2026
b8d8c52
Merge branch 'main' into xk/enable-array-averagedtimeinterval
navidcy Jan 3, 2026
0ce572c
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Jan 5, 2026
3605423
grid is now positional?
xkykai Jan 5, 2026
3c6e04e
fix hydrostaticfreesurefacemodel syntac
xkykai Jan 5, 2026
f2f2109
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Jan 6, 2026
7f09170
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Jan 13, 2026
ec86625
Add prognostic state functions for AveragedSpecifiedTimes
xkykai Jan 13, 2026
100b033
fix restore_prognostic_state!
xkykai Jan 13, 2026
17e3dec
Import prognostic state functions in averaged_specified_times.jl
xkykai Jan 13, 2026
d600239
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Jan 16, 2026
9e059aa
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Feb 12, 2026
d5e2188
reduce tests to core numerical tests
xkykai Feb 13, 2026
aaaf6af
rename function to overlap_tolerance
xkykai Feb 13, 2026
761ffb8
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Feb 13, 2026
23c4018
fix newline
xkykai Feb 13, 2026
9abfaa5
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Feb 13, 2026
bda91a5
fix tolerance
xkykai Feb 14, 2026
3af4444
Merge branch 'xk/enable-array-averagedtimeinterval' of https://github…
xkykai Feb 14, 2026
f0f03b5
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Feb 16, 2026
4e51b47
Merge branch 'main' into xk/enable-array-averagedtimeinterval
xkykai Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 138 additions & 27 deletions src/OutputWriters/windowed_time_average.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Oceananigans.Diagnostics: AbstractDiagnostic
using Oceananigans.OutputWriters: fetch_output
using Oceananigans.Utils: AbstractSchedule, prettytime
using Oceananigans.TimeSteppers: Clock
using Dates: Period, Second, value

import Oceananigans: run_diagnostic!
import Oceananigans.Utils: TimeInterval, SpecifiedTimes
Expand All @@ -12,11 +13,11 @@ import Oceananigans.Fields: location, indices, set!

Container for parameters that configure and handle time-averaged output.
"""
mutable struct AveragedTimeInterval <: AbstractSchedule
interval :: Float64
window :: Float64
mutable struct AveragedTimeInterval{I, T} <: AbstractSchedule
interval :: I
window :: I
stride :: Int
first_actuation_time :: Float64
first_actuation_time :: T
actuations :: Int
collecting :: Bool
end
Expand Down Expand Up @@ -91,30 +92,84 @@ function (sch::AveragedTimeInterval)(model)
return scheduled
end
initialize_schedule!(sch::AveragedTimeInterval, clock) = nothing
outside_window(sch::AveragedTimeInterval, clock) = clock.time <= next_actuation_time(sch) - sch.window
outside_window(sch::AveragedTimeInterval, clock) = clock.time <= next_actuation_time(sch) - sch.window
end_of_window(sch::AveragedTimeInterval, clock) = clock.time >= next_actuation_time(sch)

TimeInterval(sch::AveragedTimeInterval) = TimeInterval(sch.interval)
Base.copy(sch::AveragedTimeInterval) = AveragedTimeInterval(sch.interval, window=sch.window, stride=sch.stride)



"""
mutable struct AveragedSpecifiedTimes <: AbstractSchedule

A schedule for averaging over windows that precede SpecifiedTimes.
"""
mutable struct AveragedSpecifiedTimes <: AbstractSchedule
specified_times :: SpecifiedTimes
window :: Float64
mutable struct AveragedSpecifiedTimes{S, W} <: AbstractSchedule
specified_times :: S
window :: W
stride :: Int
collecting :: Bool
end

const VaryingWindowAveragedSpecifiedTimes = AveragedSpecifiedTimes{<:Any, <:Vector}

AveragedSpecifiedTimes(specified_times::SpecifiedTimes; window, stride=1) =
AveragedSpecifiedTimes(specified_times, window, stride, false)

AveragedSpecifiedTimes(times; kw...) = AveragedSpecifiedTimes(SpecifiedTimes(times); kw...)
AveragedSpecifiedTimes(times; window, kw...) = AveragedSpecifiedTimes(times, window; kw...)

determine_epsilon(eltype) = 0
determine_epsilon(::Type{T}) where T <: AbstractFloat = eps(T)
determine_epsilon(::Type{<:Period}) = Second(0)

const NumberTypeWindows = Union{Number, Vector{<:Number}}
const PeriodTypeWindows = Union{Period, Vector{<:Period}}


validate_windows(times, window) = nothing # Fallback method

function validate_windows(times, window::NumberTypeWindows)
tol = 10 * determine_epsilon(eltype(times))

gaps = diff(vcat(0, times)) # Prepend 0 to check first window against t=0
any(gaps .- window .< -tol) && throw(ArgumentError("Averaging windows overlap: some gaps between specified times are less than the window size."))
end

function validate_windows(times, window::PeriodTypeWindows)
if length(times) >= 2
window_starts = times .- window
prev_window_ends = times[1:end-1]
any(window_starts[2:end] .< prev_window_ends) && throw(ArgumentError("Averaging windows overlap: some gaps between specified times are less than the window size."))
end

# Note: We cannot check if the first window extends before the simulation start
# because the model clock is not available at construction time
end

function AveragedSpecifiedTimes(times, window::Vector; kw...)
length(window) == length(times) || throw(ArgumentError("When providing a vector of windows, its length $(length(window)) must match the number of specified times $(length(times))."))
perm = sortperm(times)
sorted_times = times[perm]
sorted_window = window[perm]

# Check for overlapping windows
validate_windows(sorted_times, sorted_window)

return AveragedSpecifiedTimes(SpecifiedTimes(sorted_times); window=sorted_window, kw...)
end

function AveragedSpecifiedTimes(times, window; kw...)
specified_times = SpecifiedTimes(times)

# Check for overlapping windows (scalar window case)
if length(specified_times.times) > 1
validate_windows(specified_times.times, window)
end

return AveragedSpecifiedTimes(specified_times; window, kw...)
end

get_next_window(schedule::VaryingWindowAveragedSpecifiedTimes) = schedule.window[schedule.specified_times.previous_actuation + 1]
get_next_window(schedule) = schedule.window

function (schedule::AveragedSpecifiedTimes)(model)
time = model.clock.time
Expand All @@ -123,7 +178,7 @@ function (schedule::AveragedSpecifiedTimes)(model)
next > length(schedule.specified_times.times) && return false

next_time = schedule.specified_times.times[next]
window = schedule.window
window = get_next_window(schedule)

schedule.collecting || time >= next_time - window
end
Expand All @@ -134,7 +189,8 @@ function outside_window(schedule::AveragedSpecifiedTimes, clock)
next = schedule.specified_times.previous_actuation + 1
next > length(schedule.specified_times.times) && return true
next_time = schedule.specified_times.times[next]
return clock.time < next_time - schedule.window
window = get_next_window(schedule)
return clock.time <= next_time - window
end

function end_of_window(schedule::AveragedSpecifiedTimes, clock)
Expand All @@ -144,22 +200,27 @@ function end_of_window(schedule::AveragedSpecifiedTimes, clock)
return clock.time >= next_time
end

TimeInterval(sch::AveragedSpecifiedTimes) = TimeInterval(sch.specified_times.times)
Base.copy(sch::AveragedSpecifiedTimes) = AveragedSpecifiedTimes(copy(sch.specified_times); window=sch.window, stride=sch.stride)

next_actuation_time(sch::AveragedSpecifiedTimes) = Oceananigans.Utils.next_actuation_time(sch.specified_times)

#####
##### WindowedTimeAverage
#####

mutable struct WindowedTimeAverage{OP, R, S} <: AbstractDiagnostic
mutable struct WindowedTimeAverage{OP, R, T, S} <: AbstractDiagnostic
result :: R
operand :: OP
window_start_time :: Float64
window_start_time :: T
window_start_iteration :: Int
previous_collection_time :: Float64
previous_collection_time :: T
schedule :: S
fetch_operand :: Bool
end

const IntervalWindowedTimeAverage = WindowedTimeAverage{<:Any, <:Any, <:AveragedTimeInterval}
const SpecifiedWindowedTimeAverage = WindowedTimeAverage{<:Any, <:Any, <:AveragedSpecifiedTimes}
const IntervalWindowedTimeAverage = WindowedTimeAverage{<:Any, <:Any, <:Any, <:AveragedTimeInterval}
const SpecifiedWindowedTimeAverage = WindowedTimeAverage{<:Any, <:Any, <:Any, <:AveragedSpecifiedTimes}

stride(wta::IntervalWindowedTimeAverage) = wta.schedule.stride
stride(wta::SpecifiedWindowedTimeAverage) = wta.schedule.stride
Expand All @@ -168,7 +229,7 @@ stride(wta::SpecifiedWindowedTimeAverage) = wta.schedule.stride
WindowedTimeAverage(operand, model=nothing; schedule)

Returns an object for computing running averages of `operand` over `schedule.window` and
recurring on `schedule.interval`, where `schedule` is an `AveragedTimeInterval`.
recurring on `schedule.interval`, where `schedule` is an `AveragedTimeInterval` or `AveragedSpecifiedTimes`.
During the collection period, averages are computed every `schedule.stride` iteration.

`operand` may be a `Oceananigans.Field` or a function that returns an array or scalar.
Expand All @@ -186,9 +247,15 @@ function WindowedTimeAverage(operand, model=nothing; schedule, fetch_operand=tru
result .= operand
end

return WindowedTimeAverage(result, operand, 0.0, 0, 0.0, schedule, fetch_operand)
time = isnothing(model) ? get_default_time(schedule) : model.clock.time

return WindowedTimeAverage(result, operand, time, 0, time, schedule, fetch_operand)
end

# Helper functions to get default time based on schedule type
get_default_time(schedule::AveragedTimeInterval) = zero(typeof(schedule.interval))
get_default_time(schedule::AveragedSpecifiedTimes) = zero(eltype(schedule.specified_times.times))

# Time-averaging doesn't change spatial location
location(wta::WindowedTimeAverage) = location(wta.operand)
indices(wta::WindowedTimeAverage) = indices(wta.operand)
Expand All @@ -213,12 +280,15 @@ function accumulate_result!(wta, model)
return accumulate_result!(wta, model.clock, integrand)
end

period_to_number(p::Period) = value(p)
period_to_number(n::Number) = n

function accumulate_result!(wta, clock::Clock, integrand=wta.operand)
# Time increment:
Δt = clock.time - wta.previous_collection_time
Δt = period_to_number(clock.time - wta.previous_collection_time)
# Time intervals:
T_current = clock.time - wta.window_start_time
T_previous = wta.previous_collection_time - wta.window_start_time
T_current = period_to_number(clock.time - wta.window_start_time)
T_previous = period_to_number(wta.previous_collection_time - wta.window_start_time)

# Accumulate left Riemann sum
@. wta.result = (wta.result * T_previous + integrand * Δt) / T_current
Expand Down Expand Up @@ -261,6 +331,40 @@ function advance_time_average!(wta::WindowedTimeAverage, model)
return nothing
end

function advance_time_average!(wta::SpecifiedWindowedTimeAverage, model)

unscheduled = model.clock.iteration == 0 || outside_window(wta.schedule, model.clock)
if !(unscheduled)
if !(wta.schedule.collecting)
# Zero out result to begin new accumulation window
wta.result .= 0

# Begin collecting window-averaged increments
wta.schedule.collecting = true

wta.window_start_time = next_actuation_time(wta.schedule) - get_next_window(wta.schedule)
wta.previous_collection_time = wta.window_start_time
wta.window_start_iteration = model.clock.iteration - 1
# @info "t $(prettytime(model.clock.time)), next actuation time: $(prettytime(next_actuation_time(wta.schedule))), window $(prettytime(wta.schedule.window))"
end

if end_of_window(wta.schedule, model.clock)
accumulate_result!(wta, model)
# Save averaging start time and the initial data collection time
wta.schedule.collecting = false
wta.schedule.specified_times.previous_actuation += 1

elseif mod(model.clock.iteration - wta.window_start_iteration, stride(wta)) == 0
accumulate_result!(wta, model)
else
# Off stride, so do nothing.
end

end
return nothing
end


# So it can be used as a Diagnostic
run_diagnostic!(wta::WindowedTimeAverage, model) = advance_time_average!(wta, model)

Expand All @@ -271,8 +375,14 @@ Base.summary(schedule::AveragedTimeInterval) = string("AveragedTimeInterval(",
"stride=", schedule.stride, ", ",
"interval=", prettytime(schedule.interval), ")")

Base.summary(schedule::AveragedSpecifiedTimes) = string("AveragedSpecifiedTimes(",
"window=", prettytime(schedule.window), ", ",
"stride=", schedule.stride, ", ",
"times=", schedule.specified_times, ")")

show_averaging_schedule(schedule) = ""
show_averaging_schedule(schedule::AveragedTimeInterval) = string(" averaged on ", summary(schedule))
show_averaging_schedule(schedule::AveragedSpecifiedTimes) = string(" averaged on ", summary(schedule))

output_averaging_schedule(output::WindowedTimeAverage) = output.schedule

Expand All @@ -282,6 +392,8 @@ output_averaging_schedule(output::WindowedTimeAverage) = output.schedule

time_average_outputs(schedule, outputs, model) = schedule, outputs # fallback

const AveragedTimeSchedule = Union{AveragedTimeInterval, AveragedSpecifiedTimes}

"""
time_average_outputs(schedule::AveragedTimeInterval, outputs, model, field_slicer)

Expand All @@ -290,17 +402,16 @@ Wrap each `output` in a `WindowedTimeAverage` on the time-averaged `schedule` an
Returns the `TimeInterval` associated with `schedule` and a `NamedTuple` or `Dict` of the wrapped
outputs.
"""
function time_average_outputs(schedule::AveragedTimeInterval, outputs::Dict, model)
function time_average_outputs(schedule::AveragedTimeSchedule, outputs::Dict, model)
averaged_outputs = Dict(name => WindowedTimeAverage(output, model; schedule=copy(schedule))
for (name, output) in outputs)

return TimeInterval(schedule), averaged_outputs
end

function time_average_outputs(schedule::AveragedTimeInterval, outputs::NamedTuple, model)
function time_average_outputs(schedule::AveragedTimeSchedule, outputs::NamedTuple, model)
averaged_outputs = NamedTuple(name => WindowedTimeAverage(outputs[name], model; schedule=copy(schedule))
for name in keys(outputs))

return TimeInterval(schedule), averaged_outputs
end

end
1 change: 1 addition & 0 deletions src/Utils/prettytime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,4 @@ function prettytimeunits(t, longform=true)
end

prettytime(dt::AbstractTime) = "$dt"
prettytime(t::Array) = prettytime.(t)
2 changes: 2 additions & 0 deletions src/Utils/schedules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ function specified_times_str(st)
return string(str, "]")
end

Base.copy(st::SpecifiedTimes) = SpecifiedTimes(copy(st.times), st.previous_actuation)

#####
##### ConsecutiveIterations
#####
Expand Down
12 changes: 12 additions & 0 deletions src/Utils/times_and_datetimes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,23 @@ end
@inline add_time_interval(base::AbstractTime, interval::Number, count=1) = base + seconds_to_nanosecond(interval * count)
@inline add_time_interval(base::AbstractTime, interval::Period, count=1) = base + count * interval

@inline add_time_interval(base, interval::Array{<:Number}, count=1) = interval[count]
@inline add_time_interval(base, interval::Array{<:Dates.Period}, count=1) = seconds_to_nanosecond(interval[count])
@inline add_time_interval(base, interval::Array{Dates.DateTime}, count=1) = interval[count]

function period_type(interval::Number)
FT = Oceananigans.defaults.FloatType
return FT
end

function period_type(interval::Array{<:Number})
FT = Oceananigans.defaults.FloatType
return Array{FT, 1}
end

period_type(interval::Dates.Period) = typeof(interval)
period_type(interval) = typeof(interval)

time_type(interval::Number) = typeof(interval)
time_type(interval::Dates.Period) = Dates.DateTime
time_type(interval::Array) = eltype(interval)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ CUDA.allowscalar() do
include("test_implicit_diffusion_diagnostic.jl")
include("test_output_writers.jl")
include("test_output_readers.jl")
include("test_averaged_specified_times.jl")
include("test_set_field_time_series.jl")
end
end
Expand Down
Loading