Skip to content

Commit b47dad9

Browse files
Merge pull request #27 from davidbrochart/xarray_dataset
Add xarray Dataset
2 parents d5a110d + 7097fc3 commit b47dad9

File tree

5 files changed

+159
-12
lines changed

5 files changed

+159
-12
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ conda install -c conda-forge traittypes
4040

4141
## Usage
4242

43-
`traittypes` extends the `traitlets` library with an implementation of trait types for numpy arrays, pandas dataframes and pandas series.
43+
`traittypes` extends the `traitlets` library with an implementation of trait types for numpy arrays, pandas dataframes, pandas series, xarray datasets and xarray dataarrays.
4444
- `traittypes` works around some limitations with numpy array comparison to only trigger change events when necessary.
4545
- `traittypes` also extends the traitlets API for adding custom validators to constained proposed values for the attribute.
4646

docs/source/api_documentation.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,11 @@ The ``DataFrame`` trait type holds a pandas DataFrame.
2020
The ``Series`` trait type holds a pandas Series.
2121

2222
.. autoclass:: traittypes.traittypes.Series
23+
24+
The ``Dataset`` trait type holds an xarray Dataset.
25+
26+
.. autoclass:: traittypes.traittypes.Dataset
27+
28+
The ``DataArray`` trait type holds an xarray DataArray.
29+
30+
.. autoclass:: traittypes.traittypes.DataArray

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
'test': [
8686
'numpy',
8787
'pandas',
88+
'xarray',
8889
'pytest', # traitlets[test] require this
8990
]
9091
}

traittypes/tests/test_traittypes.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from unittest import TestCase
88
from traitlets import HasTraits, TraitError, observe, Undefined
99
from traitlets.tests.test_traitlets import TraitTestBase
10-
from traittypes import Array, DataFrame, Series
10+
from traittypes import Array, DataFrame, Series, Dataset, DataArray
1111
import numpy as np
1212
import pandas as pd
13+
import xarray as xr
1314

1415

1516
# Good / Bad value trait test cases
@@ -178,4 +179,65 @@ class Foo(HasTraits):
178179
foo = Foo()
179180
with self.assertRaises(TraitError):
180181
foo.bar = None
181-
foo.baz = None
182+
foo.baz = None
183+
184+
185+
class TestDataset(TestCase):
186+
187+
def test_ds_equal(self):
188+
notifications = []
189+
class Foo(HasTraits):
190+
bar = Dataset({'foo': xr.DataArray([[0, 1, 2], [3, 4, 5]], coords={'x': ['a', 'b']}, dims=('x', 'y')), 'bar': ('x', [1, 2]), 'baz': 3.14})
191+
@observe('bar')
192+
def _(self, change):
193+
notifications.append(change)
194+
foo = Foo()
195+
foo.bar = {'foo': xr.DataArray([[0, 1, 2], [3, 4, 5]], coords={'x': ['a', 'b']}, dims=('x', 'y')), 'bar': ('x', [1, 2]), 'baz': 3.14}
196+
self.assertEqual(notifications, [])
197+
foo.bar = {'foo': xr.DataArray([[0, 1, 2], [3, 4, 5]], coords={'x': ['a', 'b']}, dims=('x', 'y')), 'bar': ('x', [1, 2]), 'baz': 3.15}
198+
self.assertEqual(len(notifications), 1)
199+
200+
def test_initial_values(self):
201+
class Foo(HasTraits):
202+
a = Dataset()
203+
b = Dataset(None, allow_none=True)
204+
d = Dataset(Undefined)
205+
foo = Foo()
206+
self.assertTrue(foo.a.equals(xr.Dataset()))
207+
self.assertTrue(foo.b is None)
208+
self.assertTrue(foo.d is Undefined)
209+
210+
def test_allow_none(self):
211+
class Foo(HasTraits):
212+
bar = Dataset()
213+
baz = Dataset(allow_none=True)
214+
foo = Foo()
215+
with self.assertRaises(TraitError):
216+
foo.bar = None
217+
foo.baz = None
218+
219+
220+
class TestDataArray(TestCase):
221+
222+
def test_ds_equal(self):
223+
notifications = []
224+
class Foo(HasTraits):
225+
bar = DataArray([[0, 1], [2, 3]])
226+
@observe('bar')
227+
def _(self, change):
228+
notifications.append(change)
229+
foo = Foo()
230+
foo.bar = [[0, 1], [2, 3]]
231+
self.assertEqual(notifications, [])
232+
foo.bar = [[0, 1], [2, 4]]
233+
self.assertEqual(len(notifications), 1)
234+
235+
def test_initial_values(self):
236+
class Foo(HasTraits):
237+
b = DataArray(None, allow_none=True)
238+
c = DataArray([])
239+
d = DataArray(Undefined)
240+
foo = Foo()
241+
self.assertTrue(foo.b is None)
242+
self.assertTrue(foo.c.equals(xr.DataArray([])))
243+
self.assertTrue(foo.d is Undefined)

traittypes/traittypes.py

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ def __getattribute__(self, name):
1515
import numpy as np
1616
except ImportError:
1717
np = _DelayedImportError('numpy')
18-
try:
19-
import pandas as pd
20-
except ImportError:
21-
pd = _DelayedImportError('pandas')
2218

2319

2420
Empty = Sentinel('Empty', 'traittypes',
@@ -30,7 +26,7 @@ def __getattribute__(self, name):
3026

3127
class SciType(TraitType):
3228

33-
"""A base trait type for numpy arrays, pandas dataframes and series."""
29+
"""A base trait type for numpy arrays, pandas dataframes, pandas series, xarray datasets and xarray dataarrays."""
3430

3531
def __init__(self, **kwargs):
3632
super(SciType, self).__init__(**kwargs)
@@ -132,9 +128,9 @@ def make_dynamic_default(self):
132128

133129
class PandasType(SciType):
134130

135-
"""A pandas dataframe trait type."""
131+
"""A pandas dataframe or series trait type."""
136132

137-
info_text = 'a pandas dataframe'
133+
info_text = 'a pandas dataframe or series'
138134

139135
klass = None
140136

@@ -158,15 +154,14 @@ def set(self, obj, value):
158154
not old_value.equals(new_value)):
159155
obj._notify_trait(self.name, old_value, new_value)
160156

161-
def __init__(self, default_value=Empty, allow_none=False, dtype=None, klass=None, **kwargs):
157+
def __init__(self, default_value=Empty, allow_none=False, klass=None, **kwargs):
162158
if klass is None:
163159
klass = self.klass
164160
if (klass is not None) and inspect.isclass(klass):
165161
self.klass = klass
166162
else:
167163
raise TraitError('The klass attribute must be a class'
168164
' not: %r' % klass)
169-
self.dtype = dtype
170165
if default_value is Empty:
171166
default_value = klass()
172167
elif default_value is not None and default_value is not Undefined:
@@ -199,10 +194,91 @@ class Series(PandasType):
199194
"""A pandas series trait type."""
200195

201196
info_text = 'a pandas series'
197+
dtype = None
202198

203199
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
204200
if 'klass' not in kwargs and self.klass is None:
205201
import pandas as pd
206202
kwargs['klass'] = pd.Series
207203
super(Series, self).__init__(
208204
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)
205+
self.dtype = dtype
206+
207+
208+
class XarrayType(SciType):
209+
210+
"""An xarray dataset or dataarray trait type."""
211+
212+
info_text = 'an xarray dataset or dataarray'
213+
214+
klass = None
215+
216+
def validate(self, obj, value):
217+
if value is None and not self.allow_none:
218+
self.error(obj, value)
219+
if value is None or value is Undefined:
220+
return super(XarrayType, self).validate(obj, value)
221+
try:
222+
value = self.klass(value)
223+
except (ValueError, TypeError) as e:
224+
raise TraitError(e)
225+
return super(XarrayType, self).validate(obj, value)
226+
227+
def set(self, obj, value):
228+
new_value = self._validate(obj, value)
229+
old_value = obj._trait_values.get(self.name, self.default_value)
230+
obj._trait_values[self.name] = new_value
231+
if ((old_value is None and new_value is not None) or
232+
(old_value is Undefined and new_value is not Undefined) or
233+
not old_value.equals(new_value)):
234+
obj._notify_trait(self.name, old_value, new_value)
235+
236+
def __init__(self, default_value=Empty, allow_none=False, klass=None, **kwargs):
237+
if klass is None:
238+
klass = self.klass
239+
if (klass is not None) and inspect.isclass(klass):
240+
self.klass = klass
241+
else:
242+
raise TraitError('The klass attribute must be a class'
243+
' not: %r' % klass)
244+
if default_value is Empty:
245+
default_value = klass()
246+
elif default_value is not None and default_value is not Undefined:
247+
default_value = klass(default_value)
248+
super(XarrayType, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
249+
250+
def make_dynamic_default(self):
251+
if self.default_value is None or self.default_value is Undefined:
252+
return self.default_value
253+
else:
254+
return self.default_value.copy()
255+
256+
257+
class Dataset(XarrayType):
258+
259+
"""An xarray dataset trait type."""
260+
261+
info_text = 'an xarray dataset'
262+
263+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
264+
if 'klass' not in kwargs and self.klass is None:
265+
import xarray as xr
266+
kwargs['klass'] = xr.Dataset
267+
super(Dataset, self).__init__(
268+
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)
269+
270+
271+
class DataArray(XarrayType):
272+
273+
"""An xarray dataarray trait type."""
274+
275+
info_text = 'an xarray dataarray'
276+
dtype = None
277+
278+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
279+
if 'klass' not in kwargs and self.klass is None:
280+
import xarray as xr
281+
kwargs['klass'] = xr.DataArray
282+
super(DataArray, self).__init__(
283+
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)
284+
self.dtype = dtype

0 commit comments

Comments
 (0)