Skip to content

Commit bcd8b6e

Browse files
committed
implement coders, adapt tests
1 parent 1c81162 commit bcd8b6e

File tree

3 files changed

+175
-130
lines changed

3 files changed

+175
-130
lines changed

xarray/coding/variables.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,72 @@ def __repr__(self) -> str:
7878
)
7979

8080

81+
class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin):
82+
"""Decode arrays on the fly from non-native to native endianness
83+
84+
This is useful for decoding arrays from netCDF3 files (which are all
85+
big endian) into native endianness, so they can be used with Cython
86+
functions, such as those found in bottleneck and pandas.
87+
88+
>>> x = np.arange(5, dtype=">i2")
89+
90+
>>> x.dtype
91+
dtype('>i2')
92+
93+
>>> NativeEndiannessArray(x).dtype
94+
dtype('int16')
95+
96+
>>> indexer = indexing.BasicIndexer((slice(None),))
97+
>>> NativeEndiannessArray(x)[indexer].dtype
98+
dtype('int16')
99+
"""
100+
101+
__slots__ = ("array",)
102+
103+
def __init__(self, array):
104+
self.array = indexing.as_indexable(array)
105+
106+
@property
107+
def dtype(self):
108+
return np.dtype(self.array.dtype.kind + str(self.array.dtype.itemsize))
109+
110+
def __getitem__(self, key):
111+
return np.asarray(self.array[key], dtype=self.dtype)
112+
113+
114+
class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
115+
"""Decode arrays on the fly from integer to boolean datatype
116+
117+
This is useful for decoding boolean arrays from integer typed netCDF
118+
variables.
119+
120+
>>> x = np.array([1, 0, 1, 1, 0], dtype="i1")
121+
122+
>>> x.dtype
123+
dtype('int8')
124+
125+
>>> BoolTypeArray(x).dtype
126+
dtype('bool')
127+
128+
>>> indexer = indexing.BasicIndexer((slice(None),))
129+
>>> BoolTypeArray(x)[indexer].dtype
130+
dtype('bool')
131+
"""
132+
133+
__slots__ = ("array",)
134+
135+
def __init__(self, array):
136+
self.array = indexing.as_indexable(array)
137+
138+
@property
139+
def dtype(self):
140+
return np.dtype("bool")
141+
142+
def __getitem__(self, key):
143+
return np.asarray(self.array[key], dtype=self.dtype)
144+
145+
146+
81147
def lazy_elemwise_func(array, func: Callable, dtype: np.typing.DTypeLike):
82148
"""Lazily apply an element-wise function to an array.
83149
Parameters
@@ -349,3 +415,99 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
349415
return Variable(dims, data, attrs, encoding, fastpath=True)
350416
else:
351417
return variable
418+
419+
420+
class DefaultFillvalueCoder(VariableCoder):
421+
"""Encode default _FillValue if needed."""
422+
423+
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
424+
dims, data, attrs, encoding = unpack_for_encoding(variable)
425+
# make NaN the fill value for float types
426+
if (
427+
"_FillValue" not in attrs
428+
and "_FillValue" not in encoding
429+
and np.issubdtype(variable.dtype, np.floating)
430+
):
431+
attrs["_FillValue"] = variable.dtype.type(np.nan)
432+
return Variable(dims, data, attrs, encoding, fastpath=True)
433+
else:
434+
return variable
435+
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
436+
raise NotImplementedError()
437+
438+
439+
class BooleanCoder(VariableCoder):
440+
"""Code boolean values."""
441+
442+
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
443+
if (
444+
(variable.dtype == bool)
445+
and ("dtype" not in variable.encoding)
446+
and ("dtype" not in variable.attrs)
447+
):
448+
dims, data, attrs, encoding = unpack_for_encoding(variable)
449+
attrs["dtype"] = "bool"
450+
data = duck_array_ops.astype(data, dtype="i1", copy=True)
451+
452+
return Variable(dims, data, attrs, encoding, fastpath=True)
453+
else:
454+
return variable
455+
456+
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
457+
if variable.attrs.get("dtype", False) == "bool":
458+
dims, data, attrs, encoding = unpack_for_decoding(variable)
459+
del attrs["dtype"]
460+
data = BoolTypeArray(data)
461+
return Variable(dims, data, attrs, encoding, fastpath=True)
462+
else:
463+
return variable
464+
465+
466+
class EndianCoder(VariableCoder):
467+
"""Decode Endianness to native."""
468+
469+
def encode(self):
470+
raise NotImplementedError()
471+
472+
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
473+
dims, data, attrs, encoding = unpack_for_decoding(variable)
474+
if not data.dtype.isnative:
475+
data = NativeEndiannessArray(data)
476+
return Variable(dims, data, attrs, encoding, fastpath=True)
477+
else:
478+
return variable
479+
480+
481+
class NonStringCoder(VariableCoder):
482+
"""Encode NonString variables if dtypes differ."""
483+
484+
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
485+
if "dtype" in variable.encoding and variable.encoding["dtype"] not in (
486+
"S1",
487+
str,
488+
):
489+
dims, data, attrs, encoding = unpack_for_encoding(variable)
490+
dtype = np.dtype(encoding.pop("dtype"))
491+
if dtype != variable.dtype:
492+
if np.issubdtype(dtype, np.integer):
493+
if (
494+
np.issubdtype(variable.dtype, np.floating)
495+
and "_FillValue" not in variable.attrs
496+
and "missing_value" not in variable.attrs
497+
):
498+
warnings.warn(
499+
f"saving variable {name} with floating "
500+
"point data as an integer dtype without "
501+
"any _FillValue to use for NaNs",
502+
SerializationWarning,
503+
stacklevel=10,
504+
)
505+
data = np.around(data)
506+
data = data.astype(dtype=dtype)
507+
return Variable(dims, data, attrs, encoding, fastpath=True)
508+
else:
509+
return variable
510+
511+
def decode(self):
512+
raise NotImplementedError()
513+

xarray/conventions.py

Lines changed: 10 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -48,123 +48,10 @@
4848
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]
4949

5050

51-
class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin):
52-
"""Decode arrays on the fly from non-native to native endianness
53-
54-
This is useful for decoding arrays from netCDF3 files (which are all
55-
big endian) into native endianness, so they can be used with Cython
56-
functions, such as those found in bottleneck and pandas.
57-
58-
>>> x = np.arange(5, dtype=">i2")
59-
60-
>>> x.dtype
61-
dtype('>i2')
62-
63-
>>> NativeEndiannessArray(x).dtype
64-
dtype('int16')
65-
66-
>>> indexer = indexing.BasicIndexer((slice(None),))
67-
>>> NativeEndiannessArray(x)[indexer].dtype
68-
dtype('int16')
69-
"""
70-
71-
__slots__ = ("array",)
72-
73-
def __init__(self, array):
74-
self.array = indexing.as_indexable(array)
75-
76-
@property
77-
def dtype(self):
78-
return np.dtype(self.array.dtype.kind + str(self.array.dtype.itemsize))
79-
80-
def __getitem__(self, key):
81-
return np.asarray(self.array[key], dtype=self.dtype)
82-
83-
84-
class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
85-
"""Decode arrays on the fly from integer to boolean datatype
86-
87-
This is useful for decoding boolean arrays from integer typed netCDF
88-
variables.
89-
90-
>>> x = np.array([1, 0, 1, 1, 0], dtype="i1")
91-
92-
>>> x.dtype
93-
dtype('int8')
94-
95-
>>> BoolTypeArray(x).dtype
96-
dtype('bool')
97-
98-
>>> indexer = indexing.BasicIndexer((slice(None),))
99-
>>> BoolTypeArray(x)[indexer].dtype
100-
dtype('bool')
101-
"""
102-
103-
__slots__ = ("array",)
104-
105-
def __init__(self, array):
106-
self.array = indexing.as_indexable(array)
107-
108-
@property
109-
def dtype(self):
110-
return np.dtype("bool")
111-
112-
def __getitem__(self, key):
113-
return np.asarray(self.array[key], dtype=self.dtype)
114-
115-
11651
def _var_as_tuple(var: Variable) -> T_VarTuple:
11752
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()
11853

11954

120-
def maybe_encode_nonstring_dtype(var: Variable, name: T_Name = None) -> Variable:
121-
if "dtype" in var.encoding and var.encoding["dtype"] not in ("S1", str):
122-
dims, data, attrs, encoding = _var_as_tuple(var)
123-
dtype = np.dtype(encoding.pop("dtype"))
124-
if dtype != var.dtype:
125-
if np.issubdtype(dtype, np.integer):
126-
if (
127-
np.issubdtype(var.dtype, np.floating)
128-
and "_FillValue" not in var.attrs
129-
and "missing_value" not in var.attrs
130-
):
131-
warnings.warn(
132-
f"saving variable {name} with floating "
133-
"point data as an integer dtype without "
134-
"any _FillValue to use for NaNs",
135-
SerializationWarning,
136-
stacklevel=10,
137-
)
138-
data = np.around(data)
139-
data = data.astype(dtype=dtype)
140-
var = Variable(dims, data, attrs, encoding, fastpath=True)
141-
return var
142-
143-
144-
def maybe_default_fill_value(var: Variable) -> Variable:
145-
# make NaN the fill value for float types:
146-
if (
147-
"_FillValue" not in var.attrs
148-
and "_FillValue" not in var.encoding
149-
and np.issubdtype(var.dtype, np.floating)
150-
):
151-
var.attrs["_FillValue"] = var.dtype.type(np.nan)
152-
return var
153-
154-
155-
def maybe_encode_bools(var: Variable) -> Variable:
156-
if (
157-
(var.dtype == bool)
158-
and ("dtype" not in var.encoding)
159-
and ("dtype" not in var.attrs)
160-
):
161-
dims, data, attrs, encoding = _var_as_tuple(var)
162-
attrs["dtype"] = "bool"
163-
data = duck_array_ops.astype(data, dtype="i1", copy=True)
164-
var = Variable(dims, data, attrs, encoding, fastpath=True)
165-
return var
166-
167-
16855
def _infer_dtype(array, name: T_Name = None) -> np.dtype:
16956
"""Given an object array with no missing values, infer its dtype from its
17057
first element
@@ -292,13 +179,13 @@ def encode_cf_variable(
292179
variables.CFScaleOffsetCoder(),
293180
variables.CFMaskCoder(),
294181
variables.UnsignedIntegerCoder(),
182+
variables.NonStringCoder(),
183+
variables.DefaultFillvalueCoder(),
184+
variables.BooleanCoder(),
295185
]:
296186
var = coder.encode(var, name=name)
297187

298-
# TODO(shoyer): convert all of these to use coders, too:
299-
var = maybe_encode_nonstring_dtype(var, name=name)
300-
var = maybe_default_fill_value(var)
301-
var = maybe_encode_bools(var)
188+
# TODO(kmuehlbauer): check if ensure_dtype_not_object can be moved to backends:
302189
var = ensure_dtype_not_object(var, name=name)
303190

304191
for attr_name in CF_RELATED_DATA:
@@ -389,19 +276,15 @@ def decode_cf_variable(
389276
if decode_times:
390277
var = times.CFDatetimeCoder(use_cftime=use_cftime).decode(var, name=name)
391278

392-
dimensions, data, attributes, encoding = variables.unpack_for_decoding(var)
393-
# TODO(shoyer): convert everything below to use coders
279+
if decode_endianness and not var.dtype.isnative:
280+
var = variables.EndianCoder().decode(var)
281+
original_dtype = var.dtype
394282

395-
if decode_endianness and not data.dtype.isnative:
396-
# do this last, so it's only done if we didn't already unmask/scale
397-
data = NativeEndiannessArray(data)
398-
original_dtype = data.dtype
283+
var = variables.BooleanCoder().decode(var)
399284

400-
encoding.setdefault("dtype", original_dtype)
285+
dimensions, data, attributes, encoding = variables.unpack_for_decoding(var)
401286

402-
if "dtype" in attributes and attributes["dtype"] == "bool":
403-
del attributes["dtype"]
404-
data = BoolTypeArray(data)
287+
encoding.setdefault("dtype", original_dtype)
405288

406289
if not is_duck_dask_array(data):
407290
data = indexing.LazilyIndexedArray(data)

xarray/tests/test_conventions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
class TestBoolTypeArray:
3333
def test_booltype_array(self) -> None:
3434
x = np.array([1, 0, 1, 1, 0], dtype="i1")
35-
bx = conventions.BoolTypeArray(x)
35+
bx = coding.variables.BoolTypeArray(x)
3636
assert bx.dtype == bool
3737
assert_array_equal(bx, np.array([True, False, True, True, False], dtype=bool))
3838

@@ -41,7 +41,7 @@ class TestNativeEndiannessArray:
4141
def test(self) -> None:
4242
x = np.arange(5, dtype=">i8")
4343
expected = np.arange(5, dtype="int64")
44-
a = conventions.NativeEndiannessArray(x)
44+
a = coding.variables.NativeEndiannessArray(x)
4545
assert a.dtype == expected.dtype
4646
assert a.dtype == expected[:].dtype
4747
assert_array_equal(a, expected)
@@ -247,7 +247,7 @@ def test_decode_coordinates(self) -> None:
247247
def test_0d_int32_encoding(self) -> None:
248248
original = Variable((), np.int32(0), encoding={"dtype": "int64"})
249249
expected = Variable((), np.int64(0))
250-
actual = conventions.maybe_encode_nonstring_dtype(original)
250+
actual = coding.variables.NonStringCoder().encode(original)
251251
assert_identical(expected, actual)
252252

253253
def test_decode_cf_with_multiple_missing_values(self) -> None:

0 commit comments

Comments
 (0)