Skip to content

Commit ba6f6b6

Browse files
committed
Handle masked variables with non-float types
Depending on the variable type, the masked value will be one of: * np.nan - for floats * np.nan + j * np.nan - for complex numbers * np.datetime64('nat') - for datetimes * np.timedelta64('nat') - for timedeltas The change makes emsarray use `xarray.core.dtypes.maybe_promote` to determine the correct mask value, preventing similar future issues if the fill value rules in xarray change.
1 parent baf799f commit ba6f6b6

File tree

8 files changed

+108
-44
lines changed

8 files changed

+108
-44
lines changed

docs/releases/development.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@
22
Next release (in development)
33
=============================
44

5-
* ...
5+
* Fixed an issue with ``_FillValue`` / ``missing_value``
6+
and variables with non-float types such as ``timedelta64``
7+
(:pr:`71`)

src/emsarray/masking.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import xarray as xr
17+
from xarray.core.dtypes import maybe_promote
1718

1819
from emsarray import utils
1920
from emsarray.types import Pathish
@@ -71,7 +72,6 @@ def mask_grid_dataset(
7172
# file system, at the added expense of having to recombine the dataset
7273
# afterwards.
7374
for key, data_array in dataset.data_vars.items():
74-
logger.debug("DataArray %s", key)
7575
masked_data_array = mask_grid_data_array(mask, data_array)
7676
variable_path = work_path / f"{key}.nc"
7777
mfdataset_names.append(variable_path)
@@ -130,19 +130,28 @@ def mask_grid_data_array(mask: xr.Dataset, data_array: xr.DataArray) -> xr.DataA
130130
try:
131131
fill_value = find_fill_value(data_array)
132132
except ValueError:
133+
logger.debug(
134+
"Data array %r has no valid fill value, leaving as is",
135+
data_array.name)
133136
return data_array
134137

135138
# Loop through each possible mask
136139
for mask_name, mask_data_array in mask.data_vars.items():
137140
# If every dimension of this mask exists in the data array, apply it
138141
if dimensions >= set(mask_data_array.dims):
142+
logger.debug(
143+
"Masking data array %r with mask %r",
144+
data_array.name, mask_name)
139145
new_data_array = cast(xr.DataArray, data_array.where(mask_data_array, other=fill_value))
140146
new_data_array.attrs = data_array.attrs
141147
new_data_array.encoding = data_array.encoding
142148
return new_data_array
143149

144150
# Fallback, no appropriate mask was found, so don't apply any.
145151
# This generally happens for data arrays such as time, record, x_grid, etc.
152+
logger.debug(
153+
"Data array %r had no relevant mask, leaving as is",
154+
data_array.name)
146155
return data_array
147156

148157

@@ -182,24 +191,16 @@ def find_fill_value(data_array: xr.DataArray) -> Any:
182191
# constructed a dataset using one...
183192
return np.ma.masked
184193

185-
if '_FillValue' in data_array.encoding:
186-
# The dataset was opened with mask_and_scale=True and a mask has been
187-
# applied. Masked values are now represented as np.nan, not _FillValue.
188-
return np.nan
189-
190-
if '_FillValue' in data_array.attrs:
191-
# The dataset was opened with mask_and_scale=False and a mask has not
192-
# been applied. Masked values should be represented using _FillValue.
193-
return data_array.attrs['_FillValue']
194-
195-
if issubclass(data_array.dtype.type, np.floating):
196-
# NaN is a useful fallback for a _FillValue, but only if the dtype
197-
# is some sort of float. We won't actually _set_ a _FillValue
198-
# attribute though, as that can play havok when trying to save
199-
# existing datasets. xarray gets real grumpy when you have
200-
# a _FillValue and a missing_value, and some existing datasets play
201-
# fast and loose with mixing the two.
202-
return np.nan
194+
attrs = ['_FillValue', 'missing_value']
195+
for attr in attrs:
196+
if attr in data_array.attrs:
197+
# The dataset was opened with mask_and_scale=False and a mask has not
198+
# been applied. Masked values should be represented using _FillValue/missing_value.
199+
return data_array.attrs[attr]
200+
201+
promoted_dtype, fill_value = maybe_promote(data_array.dtype)
202+
if promoted_dtype == data_array.dtype:
203+
return fill_value
203204

204205
raise ValueError("No appropriate fill value found")
205206

src/emsarray/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from packaging.version import Version
3232
from xarray.coding import times
3333
from xarray.core.common import contains_cftime_datetimes
34+
from xarray.core.dtypes import maybe_promote
3435

3536
from emsarray.types import Pathish
3637

@@ -233,8 +234,10 @@ def disable_default_fill_value(dataset_or_array: Union[xr.Dataset, xr.DataArray]
233234
The :class:`xarray.Dataset` or :class:`xarray.DataArray` to update
234235
"""
235236
for variable in _get_variables(dataset_or_array):
237+
current_dtype = variable.dtype
238+
promoted_dtype, fill_value = maybe_promote(current_dtype)
236239
if (
237-
issubclass(variable.dtype.type, np.floating)
240+
current_dtype == promoted_dtype
238241
and "_FillValue" not in variable.encoding
239242
and "_FillValue" not in variable.attrs
240243
):

tests/datasets/masking/find_fill_value/make_datasets.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,21 @@ def make_float_with_fill_value_and_offset(
5252
ds.close()
5353

5454

55+
def make_timedelta_with_missing_value(
56+
output_path: pathlib.Path = here / "timedelta_with_missing_value.nc",
57+
) -> None:
58+
ds = netCDF4.Dataset(output_path, "w", "NETCDF4")
59+
ds.createDimension("x", 2)
60+
ds.createDimension("y", 2)
61+
62+
missing_value = np.float32(1.e+35)
63+
var = ds.createVariable("var", "f4", ["y", "x"], fill_value=False)
64+
var.missing_value = missing_value
65+
var.units = "days"
66+
var[:] = np.arange(4).reshape((2, 2))
67+
var[1, 1] = missing_value
68+
69+
ds.close()
5570

5671

5772
def make_int_with_fill_value_and_offset(
@@ -73,4 +88,5 @@ def make_int_with_fill_value_and_offset(
7388
if __name__ == '__main__':
7489
make_float_with_fill_value()
7590
make_float_with_fill_value_and_offset()
91+
make_timedelta_with_missing_value()
7692
make_int_with_fill_value_and_offset()
Binary file not shown.

tests/masking/test_mask_dataset.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,10 @@ def test_mask_dataset(tmp_path: pathlib.Path):
155155
data=np.random.normal(0, 0.2, (records, j_size, i_size)),
156156
dims=["record", "j_centre", "i_centre"],
157157
attrs={
158-
"units": "metre",
159-
"long_name": "Surface elevation",
160-
"standard_name": "sea_surface_height_above_geoid",
161-
}
158+
"units": "metre",
159+
"long_name": "Surface elevation",
160+
"standard_name": "sea_surface_height_above_geoid",
161+
}
162162
)
163163
temp = xr.DataArray(
164164
data=np.random.normal(12, 0.5, (records, k_size, j_size, i_size)),
@@ -262,7 +262,7 @@ def test_mask_dataset(tmp_path: pathlib.Path):
262262
assert nc_flag2.shape == (k_size, 4, 3)
263263
flag2_mask = np.stack([np.array([
264264
[0, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 1]
265-
])]*k_size).astype(bool)
265+
])] * k_size).astype(bool)
266266
expected: np.ndarray = np.ma.masked_array(
267267
flag2.values[:, 1:5, 1:4].copy(),
268268
mask=flag2_mask,

tests/masking/test_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from numpy.testing import assert_equal
1010

1111
from emsarray import masking
12+
from emsarray.utils import to_netcdf_with_fixes
1213
from tests.utils import mask_from_strings
1314

1415

@@ -102,6 +103,26 @@ def test_find_fill_value_masked_and_scaled_int(datasets):
102103
assert_dtype_equal(masking.find_fill_value(data_array), np.int8(-1))
103104

104105

106+
def test_find_fill_value_timedelta_with_missing_value(
107+
datasets: pathlib.Path,
108+
tmp_path: pathlib.Path,
109+
) -> None:
110+
dataset_path = datasets / 'masking/find_fill_value/timedelta_with_missing_value.nc'
111+
112+
missing_value = np.float32(1.e35)
113+
assert_raw_values(
114+
dataset_path, 'var',
115+
np.array([[0, 1], [2, missing_value]], dtype=np.float32))
116+
117+
with xr.open_dataset(dataset_path) as dataset:
118+
data_array = dataset['var']
119+
assert dataset['var'].dtype == np.dtype('timedelta64[ns]')
120+
fill_value = masking.find_fill_value(data_array)
121+
assert np.isnat(fill_value)
122+
123+
to_netcdf_with_fixes(dataset, tmp_path / 'dataset.nc')
124+
125+
105126
def test_calculate_mask_bounds():
106127
mask = xr.Dataset(
107128
data_vars={

tests/test_utils.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,45 +57,66 @@ def test_fix_time_units_for_ems(tmp_path: pathlib.Path):
5757

5858

5959
def test_disable_default_fill_value(tmp_path: pathlib.Path):
60-
foo = xarray.DataArray(
60+
int_var = xarray.DataArray(
6161
data=np.arange(35, dtype=int).reshape(5, 7),
6262
dims=['j', 'i'],
6363
attrs={"Hello": "World"},
6464
)
65-
bar = xarray.DataArray(
66-
data=np.arange(35, dtype=np.float64).reshape(5, 7),
67-
dims=['j', 'i'],
68-
)
69-
baz = xarray.DataArray(
70-
data=np.arange(35, dtype=np.float64).reshape(5, 7),
71-
dims=['j', 'i'],
72-
)
73-
baz.data = np.where(np.tri(5, 7, dtype=bool), baz.data, np.nan)
74-
baz.encoding["_FillValue"] = np.nan
7565

76-
dataset = xarray.Dataset(data_vars={"foo": foo, "bar": bar, "baz": baz})
66+
float_var = xarray.DataArray(
67+
data=np.arange(35, dtype=np.float64).reshape(5, 7),
68+
dims=['j', 'i'])
69+
70+
f_data = np.where(
71+
np.tri(5, 7, dtype=bool),
72+
np.arange(35, dtype=np.float64).reshape(5, 7),
73+
np.nan)
74+
float_with_fill_value_var = xarray.DataArray(data=f_data, dims=['j', 'i'])
75+
float_with_fill_value_var.encoding["_FillValue"] = np.nan
76+
77+
td_data = np.where(
78+
np.tri(5, 7, dtype=bool),
79+
np.arange(35).reshape(5, 7) * np.timedelta64(1, 'D'),
80+
np.timedelta64('nat'))
81+
timedelta_with_missing_value_var = xarray.DataArray(
82+
data=td_data, dims=['j', 'i'])
83+
timedelta_with_missing_value_var.encoding['missing_value'] = np.float64('1e35')
84+
timedelta_with_missing_value_var.encoding['units'] = 'days'
85+
86+
dataset = xarray.Dataset(data_vars={
87+
"int_var": int_var,
88+
"float_var": float_var,
89+
"float_with_fill_value_var": float_with_fill_value_var,
90+
"timedelta_with_missing_value_var": timedelta_with_missing_value_var,
91+
})
7792

7893
# Save to a netCDF4 and then prove that it is bad
7994
dataset.to_netcdf(tmp_path / "bad.nc")
8095
with netCDF4.Dataset(tmp_path / "bad.nc", "r") as nc_dataset:
8196
# This one shouldn't be here because it is an integer datatype. xarray
8297
# does the right thing already in this case.
83-
assert '_FillValue' not in nc_dataset.variables["foo"].ncattrs()
98+
assert '_FillValue' not in nc_dataset.variables["int_var"].ncattrs()
8499
# This one shouldn't be here as we didnt set it, and the array is full!
85100
# This is the problem we are trying to solve
86-
assert np.isnan(nc_dataset.variables["bar"].getncattr("_FillValue"))
101+
assert np.isnan(nc_dataset.variables["float_var"].getncattr("_FillValue"))
87102
# This one is quite alright, we did explicitly set it after all
88-
assert np.isnan(nc_dataset.variables["baz"].getncattr("_FillValue"))
103+
assert np.isnan(nc_dataset.variables["float_with_fill_value_var"].getncattr("_FillValue"))
104+
# This one is incorrect, a `missing_value` attribute has already been set
105+
assert np.isnan(nc_dataset.variables["timedelta_with_missing_value_var"].getncattr("_FillValue"))
89106

90107
utils.disable_default_fill_value(dataset)
91108
dataset.to_netcdf(tmp_path / "good.nc")
92109
with netCDF4.Dataset(tmp_path / "good.nc", "r") as nc_dataset:
93110
# This one should still be unset
94-
assert '_FillValue' not in nc_dataset.variables["foo"].ncattrs()
111+
assert '_FillValue' not in nc_dataset.variables["int_var"].ncattrs()
95112
# This one should now be unset
96-
assert '_FillValue' not in nc_dataset.variables["bar"].ncattrs()
113+
assert '_FillValue' not in nc_dataset.variables["float_var"].ncattrs()
97114
# Make sure this didn't get clobbered
98-
assert np.isnan(nc_dataset.variables["baz"].getncattr("_FillValue"))
115+
assert np.isnan(nc_dataset.variables["float_with_fill_value_var"].getncattr("_FillValue"))
116+
# This one should now be unset
117+
nc_timedelta = nc_dataset.variables["timedelta_with_missing_value_var"]
118+
assert '_FillValue' not in nc_timedelta.ncattrs()
119+
assert nc_timedelta.getncattr('missing_value') == np.float64('1e35')
99120

100121

101122
def test_dataset_like():

0 commit comments

Comments
 (0)