diff --git a/src/OutputWriters/windowed_time_average.jl b/src/OutputWriters/windowed_time_average.jl index c0c98d2692..90cb6387d4 100644 --- a/src/OutputWriters/windowed_time_average.jl +++ b/src/OutputWriters/windowed_time_average.jl @@ -2,6 +2,7 @@ using Oceananigans.Diagnostics: AbstractDiagnostic using Oceananigans.OutputWriters: fetch_output using Oceananigans.Utils: AbstractSchedule, prettytime using Oceananigans.TimeSteppers: Clock +using Dates: Period import Oceananigans: run_diagnostic! import Oceananigans.Utils: TimeInterval, SpecifiedTimes @@ -12,11 +13,11 @@ import Oceananigans.Fields: location, indices, set! Container for parameters that configure and handle time-averaged output. """ -mutable struct AveragedTimeInterval <: AbstractSchedule - interval :: Float64 - window :: Float64 +mutable struct AveragedTimeInterval{I, T} <: AbstractSchedule + interval :: I + window :: I stride :: Int - first_actuation_time :: Float64 + first_actuation_time :: T actuations :: Int collecting :: Bool end @@ -91,30 +92,61 @@ function (sch::AveragedTimeInterval)(model) return scheduled end initialize_schedule!(sch::AveragedTimeInterval, clock) = nothing -outside_window(sch::AveragedTimeInterval, clock) = clock.time <= next_actuation_time(sch) - sch.window +outside_window(sch::AveragedTimeInterval, clock) = clock.time <= next_actuation_time(sch) - sch.window end_of_window(sch::AveragedTimeInterval, clock) = clock.time >= next_actuation_time(sch) TimeInterval(sch::AveragedTimeInterval) = TimeInterval(sch.interval) Base.copy(sch::AveragedTimeInterval) = AveragedTimeInterval(sch.interval, window=sch.window, stride=sch.stride) - - """ mutable struct AveragedSpecifiedTimes <: AbstractSchedule A schedule for averaging over windows that precede SpecifiedTimes. """ -mutable struct AveragedSpecifiedTimes <: AbstractSchedule - specified_times :: SpecifiedTimes - window :: Float64 +mutable struct AveragedSpecifiedTimes{S<:SpecifiedTimes, W} <: AbstractSchedule + specified_times :: S + window :: W stride :: Int collecting :: Bool end +const VaryingWindowAveragedSpecifiedTimes = AveragedSpecifiedTimes{<:Any, <:Vector} + AveragedSpecifiedTimes(specified_times::SpecifiedTimes; window, stride=1) = AveragedSpecifiedTimes(specified_times, window, stride, false) -AveragedSpecifiedTimes(times; kw...) = AveragedSpecifiedTimes(SpecifiedTimes(times); kw...) +AveragedSpecifiedTimes(times; window, kw...) = AveragedSpecifiedTimes(times, window; kw...) + +determine_epsilon(eltype) = 0 +determine_epsilon(::Type{T}) where T <: AbstractFloat = eps(T) +determine_epsilon(::Period) = Second(0) + +function AveragedSpecifiedTimes(times, window::Vector; kw...) + length(window) == length(times) || throw(ArgumentError("When providing a vector of windows, its length $(length(window)) must match the number of specified times $(length(times)).")) + perm = sortperm(times) + sorted_times = times[perm] + sorted_window = window[perm] + time_diff = diff(vcat(0, sorted_times)) + + epsilon = determine_epsilon(eltype(window)) + any(time_diff .- sorted_window .< -epsilon) && throw(ArgumentError("Averaging windows overlap. Ensure that for each specified time tᵢ, tᵢ - windowᵢ ≥ tᵢ₋₁.")) + + return AveragedSpecifiedTimes(SpecifiedTimes(sorted_times); window=sorted_window, kw...) +end + +function AveragedSpecifiedTimes(times, window; kw...) + sorted_times = sort(times) + time_diff = diff(vcat(0, sorted_times)) + + epsilon = determine_epsilon(typeof(window)) + + any(time_diff .- window .< -epsilon) && throw(ArgumentError("Averaging window $window is too large and causes overlapping windows. Ensure that for each specified time tᵢ, tᵢ - window ≥ tᵢ₋₁.")) + + return AveragedSpecifiedTimes(SpecifiedTimes(times); window, kw...) +end + +get_next_window(schedule::VaryingWindowAveragedSpecifiedTimes) = schedule.window[schedule.specified_times.previous_actuation + 1] +get_next_window(schedule) = schedule.window function (schedule::AveragedSpecifiedTimes)(model) time = model.clock.time @@ -123,7 +155,7 @@ function (schedule::AveragedSpecifiedTimes)(model) next > length(schedule.specified_times.times) && return false next_time = schedule.specified_times.times[next] - window = schedule.window + window = get_next_window(schedule) schedule.collecting || time >= next_time - window end @@ -134,7 +166,8 @@ function outside_window(schedule::AveragedSpecifiedTimes, clock) next = schedule.specified_times.previous_actuation + 1 next > length(schedule.specified_times.times) && return true next_time = schedule.specified_times.times[next] - return clock.time < next_time - schedule.window + window = get_next_window(schedule) + return clock.time < next_time - window end function end_of_window(schedule::AveragedSpecifiedTimes, clock) @@ -144,6 +177,11 @@ function end_of_window(schedule::AveragedSpecifiedTimes, clock) return clock.time >= next_time end +TimeInterval(sch::AveragedSpecifiedTimes) = TimeInterval(sch.specified_times.times) +Base.copy(sch::AveragedSpecifiedTimes) = AveragedSpecifiedTimes(copy(sch.specified_times); window=sch.window, stride=sch.stride) + +next_actuation_time(sch::AveragedSpecifiedTimes) = Oceananigans.Utils.next_actuation_time(sch.specified_times) + ##### ##### WindowedTimeAverage ##### @@ -168,7 +206,7 @@ stride(wta::SpecifiedWindowedTimeAverage) = wta.schedule.stride WindowedTimeAverage(operand, model=nothing; schedule) Returns an object for computing running averages of `operand` over `schedule.window` and -recurring on `schedule.interval`, where `schedule` is an `AveragedTimeInterval`. +recurring on `schedule.interval`, where `schedule` is an `AveragedTimeInterval` or `AveragedSpecifiedTimes`. During the collection period, averages are computed every `schedule.stride` iteration. `operand` may be a `Oceananigans.Field` or a function that returns an array or scalar. @@ -261,6 +299,40 @@ function advance_time_average!(wta::WindowedTimeAverage, model) return nothing end +function advance_time_average!(wta::SpecifiedWindowedTimeAverage, model) + + unscheduled = model.clock.iteration == 0 || outside_window(wta.schedule, model.clock) + if !(unscheduled) + if !(wta.schedule.collecting) + # Zero out result to begin new accumulation window + wta.result .= 0 + + # Begin collecting window-averaged increments + wta.schedule.collecting = true + + wta.window_start_time = next_actuation_time(wta.schedule) - get_next_window(wta.schedule) + wta.previous_collection_time = wta.window_start_time + wta.window_start_iteration = model.clock.iteration - 1 + # @info "t $(prettytime(model.clock.time)), next actuation time: $(prettytime(next_actuation_time(wta.schedule))), window $(prettytime(wta.schedule.window))" + end + + if end_of_window(wta.schedule, model.clock) + accumulate_result!(wta, model) + # Save averaging start time and the initial data collection time + wta.schedule.collecting = false + wta.schedule.specified_times.previous_actuation += 1 + + elseif mod(model.clock.iteration - wta.window_start_iteration, stride(wta)) == 0 + accumulate_result!(wta, model) + else + # Off stride, so do nothing. + end + + end + return nothing +end + + # So it can be used as a Diagnostic run_diagnostic!(wta::WindowedTimeAverage, model) = advance_time_average!(wta, model) @@ -271,8 +343,14 @@ Base.summary(schedule::AveragedTimeInterval) = string("AveragedTimeInterval(", "stride=", schedule.stride, ", ", "interval=", prettytime(schedule.interval), ")") +Base.summary(schedule::AveragedSpecifiedTimes) = string("AveragedSpecifiedTimes(", + "window=", prettytime(schedule.window), ", ", + "stride=", schedule.stride, ", ", + "times=", schedule.specified_times, ")") + show_averaging_schedule(schedule) = "" show_averaging_schedule(schedule::AveragedTimeInterval) = string(" averaged on ", summary(schedule)) +show_averaging_schedule(schedule::AveragedSpecifiedTimes) = string(" averaged on ", summary(schedule)) output_averaging_schedule(output::WindowedTimeAverage) = output.schedule @@ -282,6 +360,8 @@ output_averaging_schedule(output::WindowedTimeAverage) = output.schedule time_average_outputs(schedule, outputs, model) = schedule, outputs # fallback +const AveragedTimeSchedule = Union{AveragedTimeInterval, AveragedSpecifiedTimes} + """ time_average_outputs(schedule::AveragedTimeInterval, outputs, model, field_slicer) @@ -290,17 +370,16 @@ Wrap each `output` in a `WindowedTimeAverage` on the time-averaged `schedule` an Returns the `TimeInterval` associated with `schedule` and a `NamedTuple` or `Dict` of the wrapped outputs. """ -function time_average_outputs(schedule::AveragedTimeInterval, outputs::Dict, model) +function time_average_outputs(schedule::AveragedTimeSchedule, outputs::Dict, model) averaged_outputs = Dict(name => WindowedTimeAverage(output, model; schedule=copy(schedule)) for (name, output) in outputs) return TimeInterval(schedule), averaged_outputs end -function time_average_outputs(schedule::AveragedTimeInterval, outputs::NamedTuple, model) +function time_average_outputs(schedule::AveragedTimeSchedule, outputs::NamedTuple, model) averaged_outputs = NamedTuple(name => WindowedTimeAverage(outputs[name], model; schedule=copy(schedule)) for name in keys(outputs)) return TimeInterval(schedule), averaged_outputs -end - +end \ No newline at end of file diff --git a/src/Utils/prettytime.jl b/src/Utils/prettytime.jl index ac71dbd676..f2e3e20f85 100644 --- a/src/Utils/prettytime.jl +++ b/src/Utils/prettytime.jl @@ -64,3 +64,4 @@ function prettytimeunits(t, longform=true) end prettytime(dt::AbstractTime) = "$dt" +prettytime(t::Array) = prettytime.(t) diff --git a/src/Utils/schedules.jl b/src/Utils/schedules.jl index d7c72763c6..88bab3269c 100644 --- a/src/Utils/schedules.jl +++ b/src/Utils/schedules.jl @@ -234,6 +234,8 @@ function specified_times_str(st) return string(str, "]") end +Base.copy(st::SpecifiedTimes) = SpecifiedTimes(copy(st.times), st.previous_actuation) + ##### ##### ConsecutiveIterations ##### diff --git a/src/Utils/times_and_datetimes.jl b/src/Utils/times_and_datetimes.jl index 7521c64115..a2c0df07f8 100644 --- a/src/Utils/times_and_datetimes.jl +++ b/src/Utils/times_and_datetimes.jl @@ -29,11 +29,19 @@ end @inline add_time_interval(base::AbstractTime, interval::Number, count=1) = base + seconds_to_nanosecond(interval * count) @inline add_time_interval(base::AbstractTime, interval::Period, count=1) = base + count * interval +@inline add_time_interval(base::Number, interval::Array{<:Number}, count=1) = interval[count] + function period_type(interval::Number) FT = Oceananigans.defaults.FloatType return FT end +function period_type(interval::Array{<:Number}) + FT = Oceananigans.defaults.FloatType + return Array{FT, 1} +end + period_type(interval::Dates.Period) = typeof(interval) time_type(interval::Number) = typeof(interval) time_type(interval::Dates.Period) = Dates.DateTime +time_type(interval::Array{<:Number}) = eltype(interval)