Skip to content

Commit 6ea30f5

Browse files
committed
Refactor AveragedSpecifiedTimes to support varying window types and improve overlap validation
1 parent 6d16825 commit 6ea30f5

File tree

1 file changed

+20
-28
lines changed

1 file changed

+20
-28
lines changed

src/OutputWriters/windowed_time_average.jl

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Oceananigans.Diagnostics: AbstractDiagnostic
22
using Oceananigans.OutputWriters: fetch_output
33
using Oceananigans.Utils: AbstractSchedule, prettytime
44
using Oceananigans.TimeSteppers: Clock
5+
using Dates: Period
56

67
import Oceananigans: run_diagnostic!
78
import Oceananigans.Utils: TimeInterval, SpecifiedTimes
@@ -102,63 +103,54 @@ Base.copy(sch::AveragedTimeInterval) = AveragedTimeInterval(sch.interval, window
102103
103104
A schedule for averaging over windows that precede SpecifiedTimes.
104105
"""
105-
mutable struct AveragedSpecifiedTimes{W <: Union{Float64, Vector{Float64}}} <: AbstractSchedule
106+
mutable struct AveragedSpecifiedTimes{W} <: AbstractSchedule
106107
specified_times :: SpecifiedTimes
107108
window :: W
108109
stride :: Int
109110
collecting :: Bool
110111
end
111112

112-
const VaryingWindowAveragedSpecifiedTimes = AveragedSpecifiedTimes{Vector{Float64}}
113+
const VaryingWindowAveragedSpecifiedTimes = AveragedSpecifiedTimes{<:Vector}
113114

114115
AveragedSpecifiedTimes(specified_times::SpecifiedTimes; window, stride=1) =
115116
AveragedSpecifiedTimes(specified_times, window, stride, false)
116117

117118
AveragedSpecifiedTimes(times; window, kw...) = AveragedSpecifiedTimes(times, window; kw...)
118119

119-
function AveragedSpecifiedTimes(times, window::Vector{Float64}; kw...)
120+
function determine_epsilon(eltype)
121+
if eltype <: AbstractFloat
122+
return eps(eltype)
123+
elseif eltype <: Period
124+
return Second(0)
125+
else
126+
return 0
127+
end
128+
end
129+
130+
function AveragedSpecifiedTimes(times, window::Vector; kw...)
120131
length(window) == length(times) || throw(ArgumentError("When providing a vector of windows, its length $(length(window)) must match the number of specified times $(length(times))."))
121132
perm = sortperm(times)
122133
sorted_times = times[perm]
123134
sorted_window = window[perm]
124135
time_diff = diff(vcat(0, sorted_times))
125136

126-
any(time_diff .- sorted_window .< -eps(eltype(window))) && throw(ArgumentError("Averaging windows overlap. Ensure that for each specified time tᵢ, tᵢ - windowᵢ ≥ tᵢ₋₁."))
137+
epsilon = determine_epsilon(eltype(window))
138+
any(time_diff .- sorted_window .< -epsilon) && throw(ArgumentError("Averaging windows overlap. Ensure that for each specified time tᵢ, tᵢ - windowᵢ ≥ tᵢ₋₁."))
127139

128140
return AveragedSpecifiedTimes(SpecifiedTimes(sorted_times); window=sorted_window, kw...)
129141
end
130142

131-
function AveragedSpecifiedTimes(times, window::Float64; kw...)
143+
function AveragedSpecifiedTimes(times, window::Union{<:Number, <:Period}; kw...)
132144
sorted_times = sort(times)
133145
time_diff = diff(vcat(0, sorted_times))
134146

135-
any(time_diff .- window .< -eps(typeof(window))) && throw(ArgumentError("Averaging window $window is too large and causes overlapping windows. Ensure that for each specified time tᵢ, tᵢ - window ≥ tᵢ₋₁."))
147+
epsilon = determine_epsilon(typeof(window))
148+
149+
any(time_diff .- window .< -epsilon) && throw(ArgumentError("Averaging window $window is too large and causes overlapping windows. Ensure that for each specified time tᵢ, tᵢ - window ≥ tᵢ₋₁."))
136150

137151
return AveragedSpecifiedTimes(SpecifiedTimes(times); window, kw...)
138152
end
139153

140-
# function AveragedSpecifiedTimes(times; window, kw...)
141-
# perm = sortperm(times)
142-
# sorted_times = times[perm]
143-
# time_diff = diff(vcat(0, sorted_times))
144-
145-
# if window isa Vector{Float64}
146-
# length(window) == length(times) || throw(ArgumentError("When providing a vector of windows, its length $(length(window)) must match the number of specified times $(length(times))."))
147-
148-
# sorted_window = window[perm]
149-
# @info "timediff", time_diff
150-
# @info "sortedwindow", sorted_window
151-
152-
# any(time_diff .- sorted_window .< -eps(eltype(window))) && throw(ArgumentError("Averaging windows overlap. Ensure that for each specified time tᵢ, tᵢ - windowᵢ ≥ tᵢ₋₁."))
153-
# return AveragedSpecifiedTimes(SpecifiedTimes(sorted_times); window=sorted_window, kw...)
154-
# elseif window isa Number
155-
# any(time_diff .- window .< -eps(typeof(window))) && throw(ArgumentError("Averaging window $window is too large and causes overlapping windows. Ensure that for each specified time tᵢ, tᵢ - window ≥ tᵢ₋₁."))
156-
# return AveragedSpecifiedTimes(SpecifiedTimes(times); window, kw...)
157-
# else
158-
# throw(ArgumentError("window must be a Float64 or a Vector{Float64}, got $(typeof(window))"))
159-
# end
160-
# end
161-
162154
get_next_window(schedule::VaryingWindowAveragedSpecifiedTimes) = schedule.window[schedule.specified_times.previous_actuation + 1]
163155
get_next_window(schedule::AveragedSpecifiedTimes) = schedule.window
164156

0 commit comments

Comments
 (0)