Skip to content

Commit 7097fc3

Browse files
committed
Add xarray DataArray
1 parent 8583df4 commit 7097fc3

File tree

4 files changed

+57
-13
lines changed

4 files changed

+57
-13
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ conda install -c conda-forge traittypes
4040

4141
## Usage
4242

43-
`traittypes` extends the `traitlets` library with an implementation of trait types for numpy arrays, pandas dataframes, pandas series, and xarray datasets.
43+
`traittypes` extends the `traitlets` library with an implementation of trait types for numpy arrays, pandas dataframes, pandas series, xarray datasets and xarray dataarrays.
4444
- `traittypes` works around some limitations with numpy array comparison to only trigger change events when necessary.
4545
- `traittypes` also extends the traitlets API for adding custom validators to constained proposed values for the attribute.
4646

docs/source/api_documentation.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ The ``Series`` trait type holds a pandas Series.
2424
The ``Dataset`` trait type holds an xarray Dataset.
2525

2626
.. autoclass:: traittypes.traittypes.Dataset
27+
28+
The ``DataArray`` trait type holds an xarray DataArray.
29+
30+
.. autoclass:: traittypes.traittypes.DataArray

traittypes/tests/test_traittypes.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from unittest import TestCase
88
from traitlets import HasTraits, TraitError, observe, Undefined
99
from traitlets.tests.test_traitlets import TraitTestBase
10-
from traittypes import Array, DataFrame, Series, Dataset
10+
from traittypes import Array, DataFrame, Series, Dataset, DataArray
1111
import numpy as np
1212
import pandas as pd
1313
import xarray as xr
@@ -201,12 +201,10 @@ def test_initial_values(self):
201201
class Foo(HasTraits):
202202
a = Dataset()
203203
b = Dataset(None, allow_none=True)
204-
c = Dataset([])
205204
d = Dataset(Undefined)
206205
foo = Foo()
207206
self.assertTrue(foo.a.equals(xr.Dataset()))
208207
self.assertTrue(foo.b is None)
209-
self.assertTrue(foo.c.equals(xr.Dataset([])))
210208
self.assertTrue(foo.d is Undefined)
211209

212210
def test_allow_none(self):
@@ -217,3 +215,29 @@ class Foo(HasTraits):
217215
with self.assertRaises(TraitError):
218216
foo.bar = None
219217
foo.baz = None
218+
219+
220+
class TestDataArray(TestCase):
221+
222+
def test_ds_equal(self):
223+
notifications = []
224+
class Foo(HasTraits):
225+
bar = DataArray([[0, 1], [2, 3]])
226+
@observe('bar')
227+
def _(self, change):
228+
notifications.append(change)
229+
foo = Foo()
230+
foo.bar = [[0, 1], [2, 3]]
231+
self.assertEqual(notifications, [])
232+
foo.bar = [[0, 1], [2, 4]]
233+
self.assertEqual(len(notifications), 1)
234+
235+
def test_initial_values(self):
236+
class Foo(HasTraits):
237+
b = DataArray(None, allow_none=True)
238+
c = DataArray([])
239+
d = DataArray(Undefined)
240+
foo = Foo()
241+
self.assertTrue(foo.b is None)
242+
self.assertTrue(foo.c.equals(xr.DataArray([])))
243+
self.assertTrue(foo.d is Undefined)

traittypes/traittypes.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __getattribute__(self, name):
2626

2727
class SciType(TraitType):
2828

29-
"""A base trait type for numpy arrays, pandas dataframes, pandas series and xarray datasets."""
29+
"""A base trait type for numpy arrays, pandas dataframes, pandas series, xarray datasets and xarray dataarrays."""
3030

3131
def __init__(self, **kwargs):
3232
super(SciType, self).__init__(**kwargs)
@@ -128,9 +128,9 @@ def make_dynamic_default(self):
128128

129129
class PandasType(SciType):
130130

131-
"""A pandas dataframe trait type."""
131+
"""A pandas dataframe or series trait type."""
132132

133-
info_text = 'a pandas dataframe'
133+
info_text = 'a pandas dataframe or series'
134134

135135
klass = None
136136

@@ -154,15 +154,14 @@ def set(self, obj, value):
154154
not old_value.equals(new_value)):
155155
obj._notify_trait(self.name, old_value, new_value)
156156

157-
def __init__(self, default_value=Empty, allow_none=False, dtype=None, klass=None, **kwargs):
157+
def __init__(self, default_value=Empty, allow_none=False, klass=None, **kwargs):
158158
if klass is None:
159159
klass = self.klass
160160
if (klass is not None) and inspect.isclass(klass):
161161
self.klass = klass
162162
else:
163163
raise TraitError('The klass attribute must be a class'
164164
' not: %r' % klass)
165-
self.dtype = dtype
166165
if default_value is Empty:
167166
default_value = klass()
168167
elif default_value is not None and default_value is not Undefined:
@@ -195,20 +194,22 @@ class Series(PandasType):
195194
"""A pandas series trait type."""
196195

197196
info_text = 'a pandas series'
197+
dtype = None
198198

199199
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
200200
if 'klass' not in kwargs and self.klass is None:
201201
import pandas as pd
202202
kwargs['klass'] = pd.Series
203203
super(Series, self).__init__(
204204
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)
205+
self.dtype = dtype
205206

206207

207208
class XarrayType(SciType):
208209

209-
"""An xarray dataset trait type."""
210+
"""An xarray dataset or dataarray trait type."""
210211

211-
info_text = 'an xarray dataset'
212+
info_text = 'an xarray dataset or dataarray'
212213

213214
klass = None
214215

@@ -232,15 +233,14 @@ def set(self, obj, value):
232233
not old_value.equals(new_value)):
233234
obj._notify_trait(self.name, old_value, new_value)
234235

235-
def __init__(self, default_value=Empty, allow_none=False, dtype=None, klass=None, **kwargs):
236+
def __init__(self, default_value=Empty, allow_none=False, klass=None, **kwargs):
236237
if klass is None:
237238
klass = self.klass
238239
if (klass is not None) and inspect.isclass(klass):
239240
self.klass = klass
240241
else:
241242
raise TraitError('The klass attribute must be a class'
242243
' not: %r' % klass)
243-
self.dtype = dtype
244244
if default_value is Empty:
245245
default_value = klass()
246246
elif default_value is not None and default_value is not Undefined:
@@ -266,3 +266,19 @@ def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
266266
kwargs['klass'] = xr.Dataset
267267
super(Dataset, self).__init__(
268268
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)
269+
270+
271+
class DataArray(XarrayType):
272+
273+
"""An xarray dataarray trait type."""
274+
275+
info_text = 'an xarray dataarray'
276+
dtype = None
277+
278+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
279+
if 'klass' not in kwargs and self.klass is None:
280+
import xarray as xr
281+
kwargs['klass'] = xr.DataArray
282+
super(DataArray, self).__init__(
283+
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)
284+
self.dtype = dtype

0 commit comments

Comments
 (0)