Skip to content

Commit 6474ab5

Browse files
committed
Add xarray Dataset
1 parent 0169cc2 commit 6474ab5

File tree

5 files changed

+115
-4
lines changed

5 files changed

+115
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ conda install -c conda-forge traittypes
4040

4141
## Usage
4242

43-
`traittypes` extends the `traitlets` library with an implementation of trait types for numpy arrays, pandas dataframes and pandas series.
43+
`traittypes` extends the `traitlets` library with an implementation of trait types for numpy arrays, pandas dataframes, pandas series, and xarray datasets.
4444
- `traittypes` works around some limitations with numpy array comparison to only trigger change events when necessary.
4545
- `traittypes` also extends the traitlets API for adding custom validators to constained proposed values for the attribute.
4646

docs/source/api_documentation.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ The ``DataFrame`` trait type holds a pandas DataFrame.
2020
The ``Series`` trait type holds a pandas Series.
2121

2222
.. autoclass:: traittypes.traittypes.Series
23+
24+
The ``Dataset`` trait type holds an xarray Dataset.
25+
26+
.. autoclass:: traittypes.traittypes.Dataset

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
'test': [
8686
'numpy',
8787
'pandas',
88+
'xarray',
8889
'pytest', # traitlets[test] require this
8990
]
9091
}

traittypes/tests/test_traittypes.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from unittest import TestCase
88
from traitlets import HasTraits, TraitError, observe, Undefined
99
from traitlets.tests.test_traitlets import TraitTestBase
10-
from traittypes import Array, DataFrame, Series
10+
from traittypes import Array, DataFrame, Series, Dataset
1111
import numpy as np
1212
import pandas as pd
13+
import xarray as xr
1314

1415

1516
# Good / Bad value trait test cases
@@ -178,4 +179,41 @@ class Foo(HasTraits):
178179
foo = Foo()
179180
with self.assertRaises(TraitError):
180181
foo.bar = None
181-
foo.baz = None
182+
foo.baz = None
183+
184+
185+
class TestDataset(TestCase):
186+
187+
def test_ds_equal(self):
188+
notifications = []
189+
class Foo(HasTraits):
190+
bar = Dataset({'foo': xr.DataArray([[0, 1, 2], [3, 4, 5]], coords={'x': ['a', 'b']}, dims=('x', 'y')), 'bar': ('x', [1, 2]), 'baz': 3.14})
191+
@observe('bar')
192+
def _(self, change):
193+
notifications.append(change)
194+
foo = Foo()
195+
foo.bar = {'foo': xr.DataArray([[0, 1, 2], [3, 4, 5]], coords={'x': ['a', 'b']}, dims=('x', 'y')), 'bar': ('x', [1, 2]), 'baz': 3.14}
196+
self.assertEqual(notifications, [])
197+
foo.bar = {'foo': xr.DataArray([[0, 1, 2], [3, 4, 5]], coords={'x': ['a', 'b']}, dims=('x', 'y')), 'bar': ('x', [1, 2]), 'baz': 3.15}
198+
self.assertEqual(len(notifications), 1)
199+
200+
def test_initial_values(self):
201+
class Foo(HasTraits):
202+
a = Dataset()
203+
b = Dataset(None, allow_none=True)
204+
c = Dataset([])
205+
d = Dataset(Undefined)
206+
foo = Foo()
207+
self.assertTrue(foo.a.equals(xr.Dataset()))
208+
self.assertTrue(foo.b is None)
209+
self.assertTrue(foo.c.equals(xr.Dataset([])))
210+
self.assertTrue(foo.d is Undefined)
211+
212+
def test_allow_none(self):
213+
class Foo(HasTraits):
214+
bar = Dataset()
215+
baz = Dataset(allow_none=True)
216+
foo = Foo()
217+
with self.assertRaises(TraitError):
218+
foo.bar = None
219+
foo.baz = None

traittypes/traittypes.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def __getattribute__(self, name):
1919
import pandas as pd
2020
except ImportError:
2121
pd = _DelayedImportError('pandas')
22+
try:
23+
import xarray as xr
24+
except ImportError:
25+
xr = _DelayedImportError('xarray')
2226

2327

2428
Empty = Sentinel('Empty', 'traittypes',
@@ -30,7 +34,7 @@ def __getattribute__(self, name):
3034

3135
class SciType(TraitType):
3236

33-
"""A base trait type for numpy arrays, pandas dataframes and series."""
37+
"""A base trait type for numpy arrays, pandas dataframes, pandas series and xarray datasets."""
3438

3539
def __init__(self, **kwargs):
3640
super(SciType, self).__init__(**kwargs)
@@ -206,3 +210,67 @@ def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
206210
kwargs['klass'] = pd.Series
207211
super(Series, self).__init__(
208212
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)
213+
214+
215+
class XarrayType(SciType):
216+
217+
"""An xarray dataset trait type."""
218+
219+
info_text = 'an xarray dataset'
220+
221+
klass = None
222+
223+
def validate(self, obj, value):
224+
if value is None and not self.allow_none:
225+
self.error(obj, value)
226+
if value is None or value is Undefined:
227+
return super(XarrayType, self).validate(obj, value)
228+
try:
229+
value = self.klass(value)
230+
except (ValueError, TypeError) as e:
231+
raise TraitError(e)
232+
return super(XarrayType, self).validate(obj, value)
233+
234+
def set(self, obj, value):
235+
new_value = self._validate(obj, value)
236+
old_value = obj._trait_values.get(self.name, self.default_value)
237+
obj._trait_values[self.name] = new_value
238+
if ((old_value is None and new_value is not None) or
239+
(old_value is Undefined and new_value is not Undefined) or
240+
not old_value.equals(new_value)):
241+
obj._notify_trait(self.name, old_value, new_value)
242+
243+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, klass=None, **kwargs):
244+
if klass is None:
245+
klass = self.klass
246+
if (klass is not None) and inspect.isclass(klass):
247+
self.klass = klass
248+
else:
249+
raise TraitError('The klass attribute must be a class'
250+
' not: %r' % klass)
251+
self.dtype = dtype
252+
if default_value is Empty:
253+
default_value = klass()
254+
elif default_value is not None and default_value is not Undefined:
255+
default_value = klass(default_value)
256+
super(XarrayType, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
257+
258+
def make_dynamic_default(self):
259+
if self.default_value is None or self.default_value is Undefined:
260+
return self.default_value
261+
else:
262+
return self.default_value.copy()
263+
264+
265+
class Dataset(XarrayType):
266+
267+
"""An xarray dataset trait type."""
268+
269+
info_text = 'an xarray dataset'
270+
271+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
272+
if 'klass' not in kwargs and self.klass is None:
273+
import xarray as xr
274+
kwargs['klass'] = xr.Dataset
275+
super(Dataset, self).__init__(
276+
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)

0 commit comments

Comments
 (0)