@@ -15,10 +15,6 @@ def __getattribute__(self, name):
15
15
import numpy as np
16
16
except ImportError :
17
17
np = _DelayedImportError ('numpy' )
18
- try :
19
- import pandas as pd
20
- except ImportError :
21
- pd = _DelayedImportError ('pandas' )
22
18
23
19
24
20
Empty = Sentinel ('Empty' , 'traittypes' ,
@@ -30,7 +26,7 @@ def __getattribute__(self, name):
30
26
31
27
class SciType (TraitType ):
32
28
33
- """A base trait type for numpy arrays, pandas dataframes and series."""
29
+ """A base trait type for numpy arrays, pandas dataframes, pandas series, xarray datasets and xarray dataarrays ."""
34
30
35
31
def __init__ (self , ** kwargs ):
36
32
super (SciType , self ).__init__ (** kwargs )
@@ -132,9 +128,9 @@ def make_dynamic_default(self):
132
128
133
129
class PandasType (SciType ):
134
130
135
- """A pandas dataframe trait type."""
131
+ """A pandas dataframe or series trait type."""
136
132
137
- info_text = 'a pandas dataframe'
133
+ info_text = 'a pandas dataframe or series '
138
134
139
135
klass = None
140
136
@@ -158,15 +154,14 @@ def set(self, obj, value):
158
154
not old_value .equals (new_value )):
159
155
obj ._notify_trait (self .name , old_value , new_value )
160
156
161
- def __init__ (self , default_value = Empty , allow_none = False , dtype = None , klass = None , ** kwargs ):
157
+ def __init__ (self , default_value = Empty , allow_none = False , klass = None , ** kwargs ):
162
158
if klass is None :
163
159
klass = self .klass
164
160
if (klass is not None ) and inspect .isclass (klass ):
165
161
self .klass = klass
166
162
else :
167
163
raise TraitError ('The klass attribute must be a class'
168
164
' not: %r' % klass )
169
- self .dtype = dtype
170
165
if default_value is Empty :
171
166
default_value = klass ()
172
167
elif default_value is not None and default_value is not Undefined :
@@ -199,10 +194,91 @@ class Series(PandasType):
199
194
"""A pandas series trait type."""
200
195
201
196
info_text = 'a pandas series'
197
+ dtype = None
202
198
203
199
def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
204
200
if 'klass' not in kwargs and self .klass is None :
205
201
import pandas as pd
206
202
kwargs ['klass' ] = pd .Series
207
203
super (Series , self ).__init__ (
208
204
default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
205
+ self .dtype = dtype
206
+
207
+
208
+ class XarrayType (SciType ):
209
+
210
+ """An xarray dataset or dataarray trait type."""
211
+
212
+ info_text = 'an xarray dataset or dataarray'
213
+
214
+ klass = None
215
+
216
+ def validate (self , obj , value ):
217
+ if value is None and not self .allow_none :
218
+ self .error (obj , value )
219
+ if value is None or value is Undefined :
220
+ return super (XarrayType , self ).validate (obj , value )
221
+ try :
222
+ value = self .klass (value )
223
+ except (ValueError , TypeError ) as e :
224
+ raise TraitError (e )
225
+ return super (XarrayType , self ).validate (obj , value )
226
+
227
+ def set (self , obj , value ):
228
+ new_value = self ._validate (obj , value )
229
+ old_value = obj ._trait_values .get (self .name , self .default_value )
230
+ obj ._trait_values [self .name ] = new_value
231
+ if ((old_value is None and new_value is not None ) or
232
+ (old_value is Undefined and new_value is not Undefined ) or
233
+ not old_value .equals (new_value )):
234
+ obj ._notify_trait (self .name , old_value , new_value )
235
+
236
+ def __init__ (self , default_value = Empty , allow_none = False , klass = None , ** kwargs ):
237
+ if klass is None :
238
+ klass = self .klass
239
+ if (klass is not None ) and inspect .isclass (klass ):
240
+ self .klass = klass
241
+ else :
242
+ raise TraitError ('The klass attribute must be a class'
243
+ ' not: %r' % klass )
244
+ if default_value is Empty :
245
+ default_value = klass ()
246
+ elif default_value is not None and default_value is not Undefined :
247
+ default_value = klass (default_value )
248
+ super (XarrayType , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
249
+
250
+ def make_dynamic_default (self ):
251
+ if self .default_value is None or self .default_value is Undefined :
252
+ return self .default_value
253
+ else :
254
+ return self .default_value .copy ()
255
+
256
+
257
+ class Dataset (XarrayType ):
258
+
259
+ """An xarray dataset trait type."""
260
+
261
+ info_text = 'an xarray dataset'
262
+
263
+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
264
+ if 'klass' not in kwargs and self .klass is None :
265
+ import xarray as xr
266
+ kwargs ['klass' ] = xr .Dataset
267
+ super (Dataset , self ).__init__ (
268
+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
269
+
270
+
271
+ class DataArray (XarrayType ):
272
+
273
+ """An xarray dataarray trait type."""
274
+
275
+ info_text = 'an xarray dataarray'
276
+ dtype = None
277
+
278
+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
279
+ if 'klass' not in kwargs and self .klass is None :
280
+ import xarray as xr
281
+ kwargs ['klass' ] = xr .DataArray
282
+ super (DataArray , self ).__init__ (
283
+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
284
+ self .dtype = dtype
0 commit comments