Skip to content

Commit cdbe169

Browse files
authored
DatasetRestoring only via Metadata (#486)
* DatasetRestoring only from metadata * better show method for DatasetRestoring and Metadata * typed arg * fix test * better DatasetRestoring show method * better DatasetRestoring show method * more robust Dataset cycling boundaries test * fix start_idx bug + add convenience metadata constructor * test the Metadata convenience constructor * disambiguate Metadata constructors * validate kwargs * fix docstring * it's isnothing * fix test * fix Metadata constructor
1 parent 09f492f commit cdbe169

File tree

4 files changed

+85
-79
lines changed

4 files changed

+85
-79
lines changed

src/DataWrangling/metadata.jl

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,37 @@ end
1717
1818
Metadata holding a specific dataset information.
1919
20-
Arguments
21-
=========
20+
Argument
21+
========
2222
- `variable_name`: a symbol representing the name of the variable (for example, `:temperature`,
2323
`:salinity`, `:u_velocity`, etc)
2424
2525
Keyword Arguments
2626
=================
27-
- `dataset`: The dataset of the dataset. Supported datasets are `ECCO2Monthly()`, `ECCO2Daily()`,
28-
`ECCO4Monthly()`, `EN4Monthly(), `RepeatYearJRA55()`, or `MultiYearJRA55()`.
29-
- `dates`: The dates of the dataset, in a `AbstractCFDateTime` format. Note this can either be a range
30-
or a vector of dates, representing a time-series. For a single date, use [`Metadatum`](@ref).
27+
- `dataset`: Supported datasets are `ECCO2Monthly()`, `ECCO2Daily()`, `ECCO4Monthly()`, `EN4Monthly()`,
28+
`RepeatYearJRA55()`, or `MultiYearJRA55()`.
29+
- `dates`: The dates of the dataset (`Dates.AbstractDateTime` or `CFTime.AbstractCFDateTime`).
30+
Note this can either be a range or a vector of dates, representing a time-series.
31+
For a single date, use [`Metadatum`](@ref).
32+
- `start_date`: If `dates = nothing`, we can prescribe the first date of metadata as a date
33+
(`Dates.AbstractDateTime` or `CFTime.AbstractCFDateTime`). `start_date` should lie
34+
within the date range of the dataset. Default: nothing.
35+
- `end_date`: If `dates = nothing`, we can prescribe the last date of metadata as a date
36+
(`Dates.AbstractDateTime` or `CFTime.AbstractCFDateTime`). `end_date` should lie
37+
within the date range of the dataset. Default: nothing.
3138
- `dir`: The directory where the dataset is stored.
3239
"""
3340
function Metadata(variable_name;
3441
dataset,
35-
dates=all_dates(dataset, variable_name)[1:1],
36-
dir=default_download_directory(dataset))
42+
dates = all_dates(dataset, variable_name),
43+
dir = default_download_directory(dataset),
44+
start_date = nothing,
45+
end_date = nothing)
46+
47+
if !isnothing(start_date) && !isnothing(end_date)
48+
@info "Slicing date range within $start_date and $end_date"
49+
dates = compute_native_date_range(dates, start_date, end_date)
50+
end
3751

3852
return Metadata(variable_name, dataset, dates, dir)
3953
end
@@ -54,7 +68,9 @@ function Metadatum(variable_name;
5468
date=first_date(dataset, variable_name),
5569
dir=default_download_directory(dataset))
5670

57-
# TODO: validate that `date` is actually a single date?
71+
date isa Union{CFTime.AbstractCFDateTime, Dates.AbstractDateTime} ||
72+
throw(ArgumentError("date must be Union{Dates.AbstractDateTime, CFTime.AbstractCFDateTime}"))
73+
5874
return Metadata(variable_name, dataset, date, dir)
5975
end
6076

@@ -69,7 +85,7 @@ Base.show(io::IO, metadata::Metadata) =
6985
"├── name: $(metadata.name)", '\n',
7086
"├── dataset: $(metadata.dataset)", '\n',
7187
"├── dates: $(metadata.dates)", '\n',
72-
"└── data directory: $(metadata.dir)")
88+
"└── dir: $(metadata.dir)")
7389

7490
# Treat Metadata as an array to allow iteration over the dates.
7591
Base.length(metadata::Metadata) = length(metadata.dates)
@@ -162,7 +178,7 @@ metadata_filename(metadata) = [metadata_filename(metadatum) for metadatum in met
162178
"""
163179
compute_native_date_range(native_dates, start_date, end_date)
164180
165-
Compute the range of dates that fall within the specified start and end date.
181+
Compute the range of `native_dates` that fall within the specified `start_date` and `end_date`.
166182
"""
167183
function compute_native_date_range(native_dates, start_date, end_date)
168184
if last(native_dates) < end_date
@@ -175,7 +191,7 @@ function compute_native_date_range(native_dates, start_date, end_date)
175191

176192
start_idx = findfirst(x -> x ≥ start_date, native_dates)
177193
end_idx = findfirst(x -> x ≥ end_date, native_dates)
178-
start_idx = start_idx > 1 ? start_idx - 1 : start_idx
194+
start_idx = (start_idx > 1 && native_dates[start_idx] > start_date) ? start_idx - 1 : start_idx
179195
end_idx = isnothing(end_idx) ? length(native_dates) : end_idx
180196

181197
return native_dates[start_idx:end_idx]

src/DataWrangling/restoring.jl

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -221,15 +221,12 @@ end
221221
end
222222

223223
"""
224-
DatasetRestoring(variable_name::Symbol, [ arch_or_grid = CPU(), ];
225-
dataset,
226-
start_date = first_date(dataset, variable_name),
227-
end_date = last_date(dataset, variable_name),
228-
time_indices_in_memory = 2,
229-
time_indexing = Cyclical(),
224+
DatasetRestoring(metadata::Metadata,
225+
arch_or_grid = CPU();
226+
rate,
230227
mask = 1,
231-
rate = 1,
232-
dir = default_download_directory(dataset),
228+
time_indices_in_memory = 2, # Not more than this if we want to use GPU!
229+
time_indexing = Cyclical(),
233230
inpainting = NearestNeighborInpainting(Inf),
234231
cache_inpainted_data = true)
235232
@@ -247,7 +244,7 @@ from the dataset of choice to the simulation grid and time.
247244
Arguments
248245
=========
249246
250-
- `variable_name`: The name of the variable to restore. Choices include:
247+
- `metadata`: The metadata for a dataset variable to restore. Choices for variables include:
251248
* `:temperature`,
252249
* `:salinity`,
253250
* `:u_velocity`,
@@ -260,54 +257,25 @@ Arguments
260257
`arch_or_grid = CPU()` or `arch_or_grid = GPU()`, data is interpolated
261258
on-the-fly when the forcing tendency is computed. Default: CPU().
262259
263-
!!! info "Providing `Metadata` instead of `variable_name`"
264-
Note that `Metadata` may be provided as the first argument instead of `variable_name`.
265-
In this case the `dataset`, `start_date`, and `end_date` kwargs (described below)
266-
cannot be provided since they are inferred from `Metadata`.
267-
268260
Keyword Arguments
269261
=================
270262
271-
- `dataset`: The dataset; required keyword argument if `variable_name` argument is provided.
272-
273-
- `start_date`: The starting date to use for the dataset. Default: `first_date(dataset, variable_name)`.
263+
- `rate`: The restoring rate, i.e., the inverse of the restoring timescale (in s⁻¹).
274264
275-
- `end_date`: The ending date to use for the dataset. Default: `end_date(dataset, variable_name)`.
265+
- `mask`: The mask value. Can be a function of `(x, y, z, time)`, an array, or a number.
276266
277267
- `time_indices_in_memory`: The number of time indices to keep in memory. The number is chosen based on
278268
a trade-off between increased performance (more indices in memory) and reduced
279269
memory footprint (fewer indices in memory). Default: 2.
280270
281-
- `time_indexing`: The time indexing scheme for the field time series.
282-
283-
- `mask`: The mask value. Can be a function of `(x, y, z, time)`, an array, or a number.
284-
285-
- `rate`: The restoring rate, i.e., the inverse of the restoring timescale (in s⁻¹).
286-
287-
- `dir`: The directory where the native data is located. If the data does not exist it will
288-
be automatically downloaded. Default: `default_download_directory(dataset)`.
271+
- `time_indexing`: The time indexing scheme for the field time series. Default: `Cyclical()`.
289272
290273
- `inpainting`: inpainting algorithm, see [`inpaint_mask!`](@ref). Default: `NearestNeighborInpainting(Inf)`.
291274
292275
- `cache_inpainted_data`: If `true`, the data is cached to disk after inpainting for later retrieving.
293276
Default: `true`.
294277
"""
295-
function DatasetRestoring(variable_name::Symbol,
296-
arch_or_grid = CPU();
297-
dataset,
298-
dir = default_download_directory(dataset),
299-
start_date = first_date(dataset, variable_name),
300-
end_date = last_date(dataset, variable_name),
301-
kw...)
302-
303-
native_dates = all_dates(dataset, variable_name)
304-
dates = compute_native_date_range(native_dates, start_date, end_date)
305-
metadata = Metadata(variable_name, dataset, dates, dir)
306-
307-
return DatasetRestoring(metadata, arch_or_grid; kw...)
308-
end
309-
310-
function DatasetRestoring(metadata,
278+
function DatasetRestoring(metadata::Metadata,
311279
arch_or_grid = CPU();
312280
rate,
313281
mask = 1,
@@ -316,6 +284,8 @@ function DatasetRestoring(metadata,
316284
inpainting = NearestNeighborInpainting(Inf),
317285
cache_inpainted_data = true)
318286

287+
download_dataset(metadata)
288+
319289
fts = FieldTimeSeries(metadata, arch_or_grid;
320290
time_indices_in_memory,
321291
time_indexing,
@@ -334,13 +304,17 @@ function DatasetRestoring(metadata,
334304
return DatasetRestoring(fts, maybe_native_grid, mask, field_name, rate)
335305
end
336306

337-
function Base.show(io::IO, p::DatasetRestoring)
307+
function Base.show(io::IO, dsr::DatasetRestoring)
338308
print(io, "DatasetRestoring:", '\n',
339-
"├── restored variable: ", summary(p.variable_name), '\n',
340-
"├── restoring dataset: ", summary(p.field_time_series.backend.metadata), '\n',
341-
"├── restoring rate: ", p.rate, '\n',
342-
"├── mask: ", summary(p.mask), '\n',
343-
"└── grid: ", summary(p.native_grid))
309+
"├── variable_name: ", summary(dsr.variable_name), '\n',
310+
"├── rate: ", dsr.rate, '\n',
311+
"├── field_time_series: ", summary(dsr.field_time_series), '\n',
312+
"│ ├── dataset: ", summary(dsr.field_time_series.backend.metadata.dataset), '\n',
313+
"│ ├── dates: ", dsr.field_time_series.backend.metadata.dates, '\n',
314+
"│ ├── time_indexing: ", summary(dsr.field_time_series.time_indexing), '\n',
315+
"│ └── dir: ", dsr.field_time_series.backend.metadata.dir, '\n',
316+
"├── mask: ", summary(dsr.mask), '\n',
317+
"└── native_grid: ", summary(dsr.native_grid))
344318
end
345319

346320
regularize_forcing(forcing::DatasetRestoring, field, field_name, model_field_names) = forcing

test/test_ecco4_en4.jl

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ for arch in test_architectures, dataset in test_datasets
173173
z = (z₁, 0))
174174

175175
for name in (:temperature, :salinity)
176-
var_restoring = DatasetRestoring(name, arch; dataset, start_date, end_date, mask, inpainting, rate=1/1000)
176+
metadata = Metadata(name; dates, dataset)
177+
var_restoring = DatasetRestoring(metadata, arch; mask, inpainting, rate=1/1000)
177178

178179
fill!(var_restoring.field_time_series[1], 1.0)
179180
fill!(var_restoring.field_time_series[2], 1.0)
@@ -210,7 +211,8 @@ for arch in test_architectures, dataset in test_datasets
210211
true
211212
end
212213

213-
forcing_T = DatasetRestoring(:temperature, arch; dataset, start_date, end_date, inpainting, rate=1/1000)
214+
Tmetadata = Metadata(:temperature; dates, dataset)
215+
forcing_T = DatasetRestoring(Tmetadata, arch; inpainting, rate=1/1000)
214216

215217
ocean = ocean_simulation(grid; forcing = (; T = forcing_T), verbose=false)
216218

@@ -233,30 +235,44 @@ for arch in test_architectures, dataset in test_datasets
233235
end_date = DateTime(1993, 5, 1)
234236
dates = start_date : Month(1) : end_date
235237

236-
T_restoring = DatasetRestoring(:temperature, arch; dataset, start_date, end_date, inpainting, rate=1/1000)
238+
time_indices_in_memory = 2
237239

238-
times = native_times(T_restoring.field_time_series.backend.metadata)
239-
ocean = ocean_simulation(grid, forcing = (; T = T_restoring))
240+
Tmetadata1 = Metadata(:temperature; dates, dataset)
241+
Tmetadata2 = Metadata(:temperature; start_date, end_date, dataset)
240242

241-
ocean.model.clock.time = times[3] + 2 * Units.days
242-
update_state!(ocean.model)
243+
for Tmetadata in (Tmetadata1, Tmetadata2)
244+
T_restoring = DatasetRestoring(Tmetadata, arch; time_indices_in_memory, inpainting, rate=1/1000)
243245

244-
@test T_restoring.field_time_series.backend.start == 3
246+
times = native_times(T_restoring.field_time_series.backend.metadata)
247+
ocean = ocean_simulation(grid, forcing = (; T = T_restoring))
245248

246-
# Compile
247-
time_step!(ocean)
249+
# start a bit after time_index
250+
time_index = 3
251+
ocean.model.clock.time = times[time_index] + 2 * Units.days
252+
update_state!(ocean.model)
248253

249-
# Try stepping out of the dataset bounds
250-
ocean.model.clock.time = last(times) + 2 * Units.days
254+
@test time_indices(T_restoring.field_time_series) ==
255+
Tuple(range(time_index, length=time_indices_in_memory))
251256

252-
update_state!(ocean.model)
257+
@test T_restoring.field_time_series.backend.start == time_index
253258

254-
@test begin
259+
# Compile
255260
time_step!(ocean)
256-
true
257-
end
258261

259-
# The backend has cycled to the end
260-
@test time_indices(T_restoring.field_time_series) == (6, 1)
262+
# Try stepping out of the dataset bounds
263+
# start a bit after last time_index
264+
ocean.model.clock.time = last(times) + 2 * Units.days
265+
266+
update_state!(ocean.model)
267+
268+
@test begin
269+
time_step!(ocean)
270+
true
271+
end
272+
273+
# The backend has cycled to the end
274+
@test time_indices(T_restoring.field_time_series) ==
275+
mod1.(Tuple(range(length(times), length=time_indices_in_memory)), length(times))
276+
end
261277
end
262278
end

test/test_jra55.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ using ClimaOcean.OceanSeaIceModels: PrescribedAtmosphere
142142
backend = JRA55NetCDFBackend(10)
143143
Ta = JRA55FieldTimeSeries(:temperature; dataset, start_date, end_date, backend)
144144

145-
@test Second(end_date - start_date).value ≈ Ta.times[end-1] - Ta.times[1]
145+
@test Second(end_date - start_date).value ≈ Ta.times[end] - Ta.times[1]
146146

147147
# Test we can access all the data
148148
for t in eachindex(Ta.times)

0 commit comments

Comments
 (0)