Commit 696938a

Merge pull request #181 from brews/fix_validate_memory
Hack fix to validate OOM errors
2 parents: ae0e87d + 22357f9

File tree

3 files changed: +62 -68 lines


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 ### Added
 - Add basic CI/CD test and build status badges to README. (PR #182, @brews)
+### Fixed
+- Fix dodola validate-dataset OOM on small workers without dask-distributed. (PR #181, @brews)
 
 ## [0.17.0] - 2022-02-17
 ### Changed
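
The fix named in the changelog entry above is the pattern visible in the dodola/services.py diff below: rather than validating a whole dataset at once, each year is checked in its own dask.delayed task, so only one year of data needs to be resident in memory at a time. A minimal, self-contained sketch of that pattern, assuming only dask, numpy, and xarray with cftime installed (the synthetic dataset and the check_year helper are illustrative, not part of dodola):

import dask
import numpy as np
import xarray as xr

# Toy stand-in for a climate dataset on a noleap calendar.
time = xr.cftime_range("2020-01-01", "2021-12-31", freq="D", calendar="noleap")
ds = xr.Dataset(
    {"tasmax": ("time", 280.0 + np.random.rand(len(time)))},
    coords={"time": time},
)

@dask.delayed
def check_year(dataset, year):
    # Partial string indexing selects a single year, bounding how much
    # data any one task touches at a time.
    d = dataset.sel(time=str(year))
    assert d["tasmax"].isnull().sum() == 0, "there are nans!"
    return True

tasks = [check_year(ds, y) for y in set(ds["time"].dt.year.data)]
print(all(dask.compute(*tasks)))  # True when every year passes

This runs on dask's default local scheduler, which is exactly the situation the changelog entry describes: a small worker with no dask-distributed cluster.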

dodola/core.py

Lines changed: 7 additions & 65 deletions
@@ -6,7 +6,6 @@
 
 import warnings
 import logging
-import dask
 import numpy as np
 import xarray as xr
 from xclim import sdba, set_options
@@ -637,71 +636,14 @@ def apply_precip_ceiling(ds, ceiling):
     return ds_corrected
 
 
-def validate_dataset(ds, var, data_type, time_period="future"):
-    """
-    Validate a Dataset. Valid for CMIP6, bias corrected and downscaled.
-
-    Raises AssertionError when validation fails.
-
-    Parameters
-    ----------
-    ds : xr.Dataset
-    var : {"tasmax", "tasmin", "dtr", "pr"}
-        Variable in Dataset to validate.
-    data_type : {"cmip6", "bias_corrected", "downscaled"}
-        Type of data output to validate.
-    time_period : {"historical", "future"}
-        Time period of data that will be validated.
-    """
-    # This is pretty rough but works to communicate the idea.
-    # Consider having failed tests raise something like ValidationError rather
-    # than AssertionErrors.
-
-    # These only read in Zarr Store metadata -- not memory intensive.
-    _test_variable_names(ds, var)
-    _test_timesteps(ds, data_type, time_period)
-
-    # Other test are done on annual selections with dask.delayed to
-    # avoid large memory errors. xr.map_blocks had trouble with this.
-    @dask.delayed
-    def memory_intensive_tests(ds, v, t):
-        d = ds.sel(time=str(t))
-
-        _test_for_nans(d, v)
-
-        if v == "tasmin":
-            _test_temp_range(d, v)
-        elif v == "tasmax":
-            _test_temp_range(d, v)
-        elif v == "dtr":
-            _test_dtr_range(d, v, data_type)
-            _test_negative_values(d, v)
-        elif v == "pr":
-            _test_negative_values(d, v)
-            _test_maximum_precip(d, v)
-        else:
-            raise ValueError(f"Argument {v=} not recognized")
-
-        # Assumes error thrown if had problem before this.
-        return True
-
-    results = []
-    for t in np.unique(ds["time"].dt.year.data):
-        logger.debug(f"Validating year {t}")
-        results.append(memory_intensive_tests(ds, var, t))
-    results = dask.compute(*results)
-    assert all(results) # Likely don't need this
-    return True
-
-
-def _test_for_nans(ds, var):
+def test_for_nans(ds, var):
     """
     Tests for presence of NaNs
     """
     assert ds[var].isnull().sum() == 0, "there are nans!"
 
 
-def _test_timesteps(ds, data_type, time_period):
+def test_timesteps(ds, data_type, time_period):
     """
     Tests that Dataset contains the correct number of timesteps (number of days on a noleap calendar)
     for the data_type/time_period combination.
@@ -763,14 +705,14 @@ def _test_timesteps(ds, data_type, time_period):
     )
 
 
-def _test_variable_names(ds, var):
+def test_variable_names(ds, var):
     """
     Test that the correct variable name exists in the file
     """
     assert var in ds.var(), "{} not in Dataset".format(var)
 
 
-def _test_temp_range(ds, var):
+def test_temp_range(ds, var):
     """
     Ensure temperature values are in a valid range
     """
@@ -781,7 +723,7 @@ def _test_temp_range(ds, var):
     ), "{} values are invalid".format(var)
 
 
-def _test_dtr_range(ds, var, data_type):
+def test_dtr_range(ds, var, data_type):
     """
     Ensure DTR values are in a valid range
     Test polar values separately since some polar values can be much higher post-bias correction.
@@ -830,7 +772,7 @@ def _test_dtr_range(ds, var, data_type):
     ), "diurnal temperature range max is {} for non-polar regions".format(non_polar_max)
 
 
-def _test_negative_values(ds, var):
+def test_negative_values(ds, var):
     """
     Tests for presence of negative values
     """
@@ -839,7 +781,7 @@ def _test_negative_values(ds, var):
     assert neg_values == 0, "there are {} negative values!".format(neg_values)
 
 
-def _test_maximum_precip(ds, var):
+def test_maximum_precip(ds, var):
     """
     Tests that max precip is reasonable
     """

dodola/services.py

Lines changed: 53 additions & 3 deletions
@@ -3,6 +3,7 @@
 from functools import wraps
 import json
 import logging
+import dask
 from dodola.core import (
     xesmf_regrid,
     standardize_gcm,
@@ -12,12 +13,18 @@
     adjust_quantiledeltamapping,
     train_analogdownscaling,
     adjust_analogdownscaling,
-    validate_dataset,
     dtr_floor,
     non_polar_dtr_ceiling,
     apply_precip_ceiling,
     xclim_units_any2pint,
     xclim_units_pint2cf,
+    test_for_nans,
+    test_variable_names,
+    test_timesteps,
+    test_temp_range,
+    test_dtr_range,
+    test_negative_values,
+    test_maximum_precip,
 )
 import dodola.repository as storage
 
@@ -701,6 +708,12 @@ def adjust_maximum_precipitation(x, out, threshold=3000.0):
 def validate(x, var, data_type, time_period):
     """Performs validation on an input dataset
 
+    Valid for CMIP6, bias corrected and downscaled. Raises AssertionError when
+    validation fails.
+
+    This function performs more memory-intensive tests by reading input data
+    and subsetting to each year in the "time" dimension.
+
     Parameters
     ----------
     x : str
@@ -714,6 +727,43 @@ def validate(x, var, data_type, time_period):
         Time period that input data should cover, used in validating the number of timesteps
         in conjunction with the data type.
     """
-
+    # This is pretty rough but works to communicate the idea.
+    # Consider having failed tests raise something like ValidationError rather
+    # than AssertionErrors.
     ds = storage.read(x)
-    validate_dataset(ds, var, data_type, time_period)
+
+    # These only read in Zarr Store metadata -- not memory intensive.
+    test_variable_names(ds, var)
+    test_timesteps(ds, data_type, time_period)
+
+    # Other test are done on annual selections with dask.delayed to
+    # avoid large memory errors.
+    # Doing all this here because this involves storage and I/O logic.
+    @dask.delayed
+    def memory_intensive_tests(f, v, t):
+        d = storage.read(f).sel(time=str(t))
+
+        test_for_nans(d, v)
+
+        if v == "tasmin":
+            test_temp_range(d, v)
+        elif v == "tasmax":
+            test_temp_range(d, v)
+        elif v == "dtr":
+            test_dtr_range(d, v, data_type)
+            test_negative_values(d, v)
+        elif v == "pr":
+            test_negative_values(d, v)
+            test_maximum_precip(d, v)
+        else:
+            raise ValueError(f"Argument {v=} not recognized")
+
+        # Assumes error thrown if had problem before this.
+        return True
+
+    tasks = []
+    for t in set(ds["time"].dt.year.data):
+        logger.debug(f"Validating year {t}")
+        tasks.append(memory_intensive_tests(x, var, t))
+    tasks = dask.compute(*tasks)
+    assert all(tasks) # Likely don't need this
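
The subtle but load-bearing change in the loop above is that memory_intensive_tests receives the storage path x rather than the opened Dataset: each task calls storage.read(f) itself, so the task graph carries a short string per year instead of a reference to the full dataset. A hypothetical invocation of the rewritten service (the Zarr URL is a placeholder, not a real store):

from dodola.services import validate

# Raises AssertionError if any metadata or per-year check fails;
# completes silently otherwise.
validate(
    "gs://my-bucket/tasmax-downscaled.zarr",  # placeholder input store
    var="tasmax",
    data_type="downscaled",
    time_period="future",
)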
