Skip to content

Commit 361cc96

Browse files
committed
Fix issue with masked timedelta values
1 parent e1c6dcd commit 361cc96

File tree

5 files changed

+100
-14
lines changed

5 files changed

+100
-14
lines changed

src/emsarray/conventions/ugrid.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,11 @@ def _to_index_array(
491491
# If a data array has a fill value, xarray will convert that data array
492492
# to a floating point data type, and replace masked values with np.nan.
493493
# Here we convert a floating point array to a masked integer array.
494-
values = np.ma.masked_invalid(values).astype(np.int_)
494+
masked_values = np.ma.masked_invalid(values)
495+
# numpy will emit a warning when converting an array with np.nan to int,
496+
# even if the nans are masked out.
497+
masked_values.data[masked_values.mask] = self.sensible_fill_value
498+
values = masked_values.astype(self.sensible_dtype)
495499
elif '_FillValue' in data_array.attrs:
496500
# The data array has a fill value, but xarray has not applied it.
497501
# This implied the dataset was opened with mask_and_scale=False,
@@ -1195,13 +1199,17 @@ def apply_clip_mask(self, clip_mask: xr.Dataset, work_dir: Pathish) -> xr.Datase
11951199
# any changes.
11961200
topology_variables: List[xr.DataArray] = [topology.mesh_variable]
11971201

1202+
# This is the fill value used in the mask.
1203+
new_fill_value = clip_mask.data_vars['new_node_index'].encoding['_FillValue']
1204+
11981205
def integer_indices(data_array: xr.DataArray) -> np.ndarray:
1199-
masked = np.ma.masked_invalid(data_array.values)
1200-
masked_integers: np.ndarray = masked.astype(np.int_)
1206+
masked_values = np.ma.masked_invalid(data_array.values)
1207+
# numpy will emit a warning when converting an array with np.nan to int,
1208+
# even if the nans are masked out.
1209+
masked_values.data[masked_values.mask] = new_fill_value
1210+
masked_integers: np.ndarray = masked_values.astype(np.int_)
12011211
return masked_integers
12021212

1203-
# This is the fill value used in the mask.
1204-
new_fill_value = clip_mask.data_vars['new_node_index'].encoding['_FillValue']
12051213
new_node_indices = integer_indices(clip_mask.data_vars['new_node_index'])
12061214
new_face_indices = integer_indices(clip_mask.data_vars['new_face_index'])
12071215
has_edges = 'new_edge_index' in clip_mask.data_vars

tests/conventions/test_ugrid.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
ConventionViolationError, ConventionViolationWarning
2424
)
2525
from emsarray.operations import geometry
26-
from tests.utils import assert_property_not_cached
26+
from tests.utils import assert_property_not_cached, filter_warning
2727

2828

2929
def make_faces(width: int, height, fill_value: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -239,10 +239,28 @@ def make_dataset(
239239
},
240240
)
241241

242+
one_day = np.timedelta64(1, 'D').astype('timedelta64[ns]')
243+
period = xr.DataArray(
244+
data=np.concatenate([
245+
np.arange(cell_size - 2, dtype=int) * one_day,
246+
[np.timedelta64('nat', 'ns')] * 2,
247+
]),
248+
dims=[face_dimension],
249+
name="period",
250+
attrs={
251+
"long_name": "Some variable counting days",
252+
},
253+
)
254+
period.encoding.update({
255+
"units": "days",
256+
"_FillValue": np.int16(-1),
257+
"dtype": np.dtype('int16'),
258+
})
259+
242260
dataset = xr.Dataset(
243261
data_vars={var.name: var for var in [
244262
mesh, face_node_connectivity, node_x, node_y,
245-
t, z, botz, eta, temp
263+
t, z, botz, eta, temp, period,
246264
]},
247265
attrs={
248266
'title': "COMPAS defalt version",
@@ -713,7 +731,22 @@ def test_apply_clip_mask(tmp_path):
713731

714732
def test_make_and_apply_clip_mask(tmp_path):
715733
dataset = make_dataset(width=5)
716-
dataset.ems.to_netcdf(tmp_path / "original.nc")
734+
735+
# When saving a dataset to disk, xarray.coding.times.cast_to_int_if_safe
736+
# will check if it is possible to encode a timedelta64 using integer values
737+
# by casting the values and checking for equality.
738+
# Recent versions of numpy will emit warnings
739+
# when casting a data array with dtype timedelta64 to int
740+
# if it contains NaT (not a time) values.
741+
# xarray will fix this eventually, but for now...
742+
# See https://github.com/pydata/xarray/issues/7942
743+
with filter_warning(
744+
'ignore', category=RuntimeWarning,
745+
message='invalid value encountered in cast',
746+
module=r'xarray\.coding\.times',
747+
):
748+
dataset.ems.to_netcdf(tmp_path / "original.nc")
749+
717750
geometry.write_geojson(dataset, tmp_path / 'original.geojson')
718751

719752
polygon = Polygon([[3.4, 1], [3.4, -1], [6, -1], [6, 1], [3.4, 1]])

tests/masking/test_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from emsarray import masking
1212
from emsarray.utils import to_netcdf_with_fixes
13-
from tests.utils import mask_from_strings
13+
from tests.utils import filter_warning, mask_from_strings
1414

1515

1616
def assert_raw_values(
@@ -120,7 +120,14 @@ def test_find_fill_value_timedelta_with_missing_value(
120120
fill_value = masking.find_fill_value(data_array)
121121
assert np.isnat(fill_value)
122122

123-
to_netcdf_with_fixes(dataset, tmp_path / 'dataset.nc')
123+
# See https://github.com/pydata/xarray/issues/7942
124+
with filter_warning(
125+
'ignore', category=RuntimeWarning,
126+
message='invalid value encountered in cast',
127+
module=r'xarray\.coding\.times',
128+
):
129+
# Write this out for easier debugging purposes
130+
to_netcdf_with_fixes(dataset, tmp_path / 'dataset.nc')
124131

125132

126133
def test_calculate_mask_bounds():

tests/test_utils.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import logging
23
import pathlib
34

45
import netCDF4
@@ -10,6 +11,9 @@
1011
import xarray.testing
1112

1213
from emsarray import utils
14+
from tests.utils import filter_warning
15+
16+
logger = logging.getLogger(__name__)
1317

1418

1519
@pytest.mark.parametrize(
@@ -76,8 +80,9 @@ def test_disable_default_fill_value(tmp_path: pathlib.Path):
7680

7781
td_data = np.where(
7882
np.tri(5, 7, dtype=bool),
79-
np.arange(35).reshape(5, 7) * np.timedelta64(1, 'D'),
80-
np.timedelta64('nat'))
83+
np.arange(35).reshape(5, 7) * np.timedelta64(1, 'D').astype('timedelta64[ns]'),
84+
np.timedelta64('nat', 'ns'))
85+
logger.info('%r', td_data)
8186
timedelta_with_missing_value_var = xarray.DataArray(
8287
data=td_data, dims=['j', 'i'])
8388
timedelta_with_missing_value_var.encoding['missing_value'] = np.float64('1e35')
@@ -91,7 +96,17 @@ def test_disable_default_fill_value(tmp_path: pathlib.Path):
9196
})
9297

9398
# Save to a netCDF4 and then prove that it is bad
94-
dataset.to_netcdf(tmp_path / "bad.nc")
99+
# This emits warnings in current versions of xarray / numpy.
100+
# See https://github.com/pydata/xarray/issues/7942
101+
with filter_warning(
102+
'default', category=RuntimeWarning,
103+
message='invalid value encountered in cast',
104+
module=r'xarray\.coding\.times',
105+
record=True,
106+
) as ws:
107+
dataset.to_netcdf(tmp_path / "bad.nc")
108+
assert len(ws) == 1
109+
95110
with netCDF4.Dataset(tmp_path / "bad.nc", "r") as nc_dataset:
96111
# This one shouldn't be here because it is an integer datatype. xarray
97112
# does the right thing already in this case.
@@ -105,7 +120,17 @@ def test_disable_default_fill_value(tmp_path: pathlib.Path):
105120
assert np.isnan(nc_dataset.variables["timedelta_with_missing_value_var"].getncattr("_FillValue"))
106121

107122
utils.disable_default_fill_value(dataset)
108-
dataset.to_netcdf(tmp_path / "good.nc")
123+
124+
# See https://github.com/pydata/xarray/issues/7942
125+
with filter_warning(
126+
'default', category=RuntimeWarning,
127+
message='invalid value encountered in cast',
128+
module=r'xarray\.coding\.times',
129+
record=True,
130+
) as ws:
131+
dataset.to_netcdf(tmp_path / "good.nc")
132+
assert len(ws) == 1
133+
109134
with netCDF4.Dataset(tmp_path / "good.nc", "r") as nc_dataset:
110135
# This one should still be unset
111136
assert '_FillValue' not in nc_dataset.variables["int_var"].ncattrs()

tests/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

33
import abc
4+
import contextlib
45
import itertools
6+
import warnings
57
from functools import cached_property
68
from typing import Any, Dict, Hashable, List, Optional, Tuple
79

@@ -14,6 +16,17 @@
1416
)
1517

1618

19+
@contextlib.contextmanager
20+
def filter_warning(*args, record: bool = False, **kwargs):
21+
"""
22+
A shortcut wrapper around warnings.catch_warning()
23+
and warnings.filterwarnings()
24+
"""
25+
with warnings.catch_warnings(record=record) as context:
26+
warnings.filterwarnings(*args, **kwargs)
27+
yield context
28+
29+
1730
def box(minx, miny, maxx, maxy) -> shapely.Polygon:
1831
"""
1932
Make a box, with coordinates going counterclockwise

0 commit comments

Comments
 (0)