1
+ import inspect
1
2
import warnings
2
3
3
- from traitlets import TraitType , TraitError , Undefined
4
+ from traitlets import TraitType , TraitError , Undefined , Sentinel
4
5
5
6
class _DelayedImportError (object ):
6
7
def __init__ (self , package_name ):
@@ -20,6 +21,13 @@ def __getattribute__(self, name):
20
21
pd = _DelayedImportError ('pandas' )
21
22
22
23
24
+ Empty = Sentinel ('Empty' , 'traittypes' ,
25
+ """
26
+ Used in traittypes to specify that the default value should
27
+ be an empty dataset
28
+ """ )
29
+
30
+
23
31
class SciType (TraitType ):
24
32
25
33
"""A base trait type for numpy arrays, pandas dataframes and series."""
@@ -107,96 +115,94 @@ def set(self, obj, value):
107
115
if not np .array_equal (old_value , new_value ):
108
116
obj ._notify_trait (self .name , old_value , new_value )
109
117
110
- def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
118
+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
111
119
self .dtype = dtype
112
- if default_value is Undefined :
120
+ if default_value is Empty :
113
121
default_value = np .array (0 , dtype = self .dtype )
114
- elif default_value is not None :
122
+ elif default_value is not None and default_value is not Undefined :
115
123
default_value = np .asarray (default_value , dtype = self .dtype )
116
124
super (Array , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
117
125
118
126
def make_dynamic_default (self ):
119
- if self .default_value is None :
127
+ if self .default_value is None or self . default_value is Undefined :
120
128
return self .default_value
121
129
else :
122
130
return np .copy (self .default_value )
123
131
124
132
125
- class DataFrame (SciType ):
133
+ class PandasType (SciType ):
126
134
127
135
"""A pandas dataframe trait type."""
128
136
129
137
info_text = 'a pandas dataframe'
130
138
139
+ klass = None
140
+
131
141
def validate (self , obj , value ):
132
142
if value is None and not self .allow_none :
133
143
self .error (obj , value )
134
144
if value is None or value is Undefined :
135
- return super (DataFrame , self ).validate (obj , value )
145
+ return super (PandasType , self ).validate (obj , value )
136
146
try :
137
- value = pd . DataFrame (value )
147
+ value = self . klass (value )
138
148
except (ValueError , TypeError ) as e :
139
149
raise TraitError (e )
140
- return super (DataFrame , self ).validate (obj , value )
150
+ return super (PandasType , self ).validate (obj , value )
141
151
142
152
def set (self , obj , value ):
143
153
new_value = self ._validate (obj , value )
144
154
old_value = obj ._trait_values .get (self .name , self .default_value )
145
155
obj ._trait_values [self .name ] = new_value
146
- if (old_value is None and new_value is not None ) or not old_value .equals (new_value ):
156
+ if ((old_value is None and new_value is not None ) or
157
+ (old_value is Undefined and new_value is not Undefined ) or
158
+ not old_value .equals (new_value )):
147
159
obj ._notify_trait (self .name , old_value , new_value )
148
160
149
- def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
150
- import pandas as pd
161
+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , klass = None , ** kwargs ):
162
+ if klass is None :
163
+ klass = self .klass
164
+ if (klass is not None ) and inspect .isclass (klass ):
165
+ self .klass = klass
166
+ else :
167
+ raise TraitError ('The klass attribute must be a class'
168
+ ' not: %r' % klass )
151
169
self .dtype = dtype
152
- if default_value is Undefined :
153
- default_value = pd . DataFrame ()
154
- elif default_value is not None :
155
- default_value = pd . DataFrame (default_value )
156
- super (DataFrame , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
170
+ if default_value is Empty :
171
+ default_value = klass ()
172
+ elif default_value is not None and default_value is not Undefined :
173
+ default_value = klass (default_value )
174
+ super (PandasType , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
157
175
158
176
def make_dynamic_default (self ):
159
- if self .default_value is None :
177
+ if self .default_value is None or self . default_value is Undefined :
160
178
return self .default_value
161
179
else :
162
180
return self .default_value .copy ()
163
181
164
182
165
- class Series ( SciType ):
183
+ class DataFrame ( PandasType ):
166
184
167
- """A pandas series trait type."""
185
+ """A pandas dataframe trait type."""
168
186
169
- info_text = 'a pandas series '
187
+ info_text = 'a pandas dataframe '
170
188
171
- def validate (self , obj , value ):
172
- if value is None and not self .allow_none :
173
- self .error (obj , value )
174
- if value is None or value is Undefined :
175
- return super (Series , self ).validate (obj , value )
176
- try :
177
- value = pd .Series (value )
178
- except (ValueError , TypeError ) as e :
179
- raise TraitError (e )
180
- return super (Series , self ).validate (obj , value )
189
+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
190
+ if 'klass' not in kwargs and self .klass is None :
191
+ import pandas as pd
192
+ kwargs ['klass' ] = pd .DataFrame
193
+ super (DataFrame , self ).__init__ (
194
+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
181
195
182
- def set (self , obj , value ):
183
- new_value = self ._validate (obj , value )
184
- old_value = obj ._trait_values .get (self .name , self .default_value )
185
- obj ._trait_values [self .name ] = new_value
186
- if (old_value is None and new_value is not None ) or not old_value .equals (new_value ):
187
- obj ._notify_trait (self .name , old_value , new_value )
188
196
189
- def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
190
- import pandas as pd
191
- self .dtype = dtype
192
- if default_value is Undefined :
193
- default_value = pd .Series ()
194
- elif default_value is not None :
195
- default_value = pd .Series (default_value )
196
- super (Series , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
197
+ class Series (PandasType ):
197
198
198
- def make_dynamic_default (self ):
199
- if self .default_value is None :
200
- return self .default_value
201
- else :
202
- return self .default_value .copy ()
199
+ """A pandas series trait type."""
200
+
201
+ info_text = 'a pandas series'
202
+
203
+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
204
+ if 'klass' not in kwargs and self .klass is None :
205
+ import pandas as pd
206
+ kwargs ['klass' ] = pd .Series
207
+ super (Series , self ).__init__ (
208
+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
0 commit comments