@@ -26,7 +26,7 @@ def __getattribute__(self, name):
26
26
27
27
class SciType (TraitType ):
28
28
29
- """A base trait type for numpy arrays, pandas dataframes, pandas series and xarray datasets ."""
29
+ """A base trait type for numpy arrays, pandas dataframes, pandas series, xarray datasets and xarray dataarrays ."""
30
30
31
31
def __init__ (self , ** kwargs ):
32
32
super (SciType , self ).__init__ (** kwargs )
@@ -128,9 +128,9 @@ def make_dynamic_default(self):
128
128
129
129
class PandasType (SciType ):
130
130
131
- """A pandas dataframe trait type."""
131
+ """A pandas dataframe or series trait type."""
132
132
133
- info_text = 'a pandas dataframe'
133
+ info_text = 'a pandas dataframe or series '
134
134
135
135
klass = None
136
136
@@ -154,15 +154,14 @@ def set(self, obj, value):
154
154
not old_value .equals (new_value )):
155
155
obj ._notify_trait (self .name , old_value , new_value )
156
156
157
- def __init__ (self , default_value = Empty , allow_none = False , dtype = None , klass = None , ** kwargs ):
157
+ def __init__ (self , default_value = Empty , allow_none = False , klass = None , ** kwargs ):
158
158
if klass is None :
159
159
klass = self .klass
160
160
if (klass is not None ) and inspect .isclass (klass ):
161
161
self .klass = klass
162
162
else :
163
163
raise TraitError ('The klass attribute must be a class'
164
164
' not: %r' % klass )
165
- self .dtype = dtype
166
165
if default_value is Empty :
167
166
default_value = klass ()
168
167
elif default_value is not None and default_value is not Undefined :
@@ -195,20 +194,22 @@ class Series(PandasType):
195
194
"""A pandas series trait type."""
196
195
197
196
info_text = 'a pandas series'
197
+ dtype = None
198
198
199
199
def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
200
200
if 'klass' not in kwargs and self .klass is None :
201
201
import pandas as pd
202
202
kwargs ['klass' ] = pd .Series
203
203
super (Series , self ).__init__ (
204
204
default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
205
+ self .dtype = dtype
205
206
206
207
207
208
class XarrayType (SciType ):
208
209
209
- """An xarray dataset trait type."""
210
+ """An xarray dataset or dataarray trait type."""
210
211
211
- info_text = 'an xarray dataset'
212
+ info_text = 'an xarray dataset or dataarray '
212
213
213
214
klass = None
214
215
@@ -232,15 +233,14 @@ def set(self, obj, value):
232
233
not old_value .equals (new_value )):
233
234
obj ._notify_trait (self .name , old_value , new_value )
234
235
235
- def __init__ (self , default_value = Empty , allow_none = False , dtype = None , klass = None , ** kwargs ):
236
+ def __init__ (self , default_value = Empty , allow_none = False , klass = None , ** kwargs ):
236
237
if klass is None :
237
238
klass = self .klass
238
239
if (klass is not None ) and inspect .isclass (klass ):
239
240
self .klass = klass
240
241
else :
241
242
raise TraitError ('The klass attribute must be a class'
242
243
' not: %r' % klass )
243
- self .dtype = dtype
244
244
if default_value is Empty :
245
245
default_value = klass ()
246
246
elif default_value is not None and default_value is not Undefined :
@@ -266,3 +266,19 @@ def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
266
266
kwargs ['klass' ] = xr .Dataset
267
267
super (Dataset , self ).__init__ (
268
268
default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
269
+
270
+
271
+ class DataArray (XarrayType ):
272
+
273
+ """An xarray dataarray trait type."""
274
+
275
+ info_text = 'an xarray dataarray'
276
+ dtype = None
277
+
278
+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
279
+ if 'klass' not in kwargs and self .klass is None :
280
+ import xarray as xr
281
+ kwargs ['klass' ] = xr .DataArray
282
+ super (DataArray , self ).__init__ (
283
+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
284
+ self .dtype = dtype
0 commit comments