1
+ import inspect
1
2
import warnings
2
3
3
4
from traitlets import TraitType , TraitError , Undefined
@@ -122,22 +123,24 @@ def make_dynamic_default(self):
122
123
return np .copy (self .default_value )
123
124
124
125
125
- class DataFrame (SciType ):
126
+ class PandasType (SciType ):
126
127
127
128
"""A pandas dataframe trait type."""
128
129
129
130
info_text = 'a pandas dataframe'
130
131
132
+ klass = None
133
+
131
134
def validate (self , obj , value ):
132
135
if value is None and not self .allow_none :
133
136
self .error (obj , value )
134
137
if value is None or value is Undefined :
135
- return super (DataFrame , self ).validate (obj , value )
138
+ return super (PandasType , self ).validate (obj , value )
136
139
try :
137
- value = pd . DataFrame (value )
140
+ value = self . klass (value )
138
141
except (ValueError , TypeError ) as e :
139
142
raise TraitError (e )
140
- return super (DataFrame , self ).validate (obj , value )
143
+ return super (PandasType , self ).validate (obj , value )
141
144
142
145
def set (self , obj , value ):
143
146
new_value = self ._validate (obj , value )
@@ -146,14 +149,20 @@ def set(self, obj, value):
146
149
if (old_value is None and new_value is not None ) or not old_value .equals (new_value ):
147
150
obj ._notify_trait (self .name , old_value , new_value )
148
151
149
- def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
150
- import pandas as pd
152
+ def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , klass = None , ** kwargs ):
153
+ if klass is None :
154
+ klass = self .klass
155
+ if (klass is not None ) and inspect .isclass (klass ):
156
+ self .klass = klass
157
+ else :
158
+ raise TraitError ('The klass attribute must be a class'
159
+ ' not: %r' % klass )
151
160
self .dtype = dtype
152
161
if default_value is Undefined :
153
- default_value = pd . DataFrame ()
162
+ default_value = klass ()
154
163
elif default_value is not None :
155
- default_value = pd . DataFrame (default_value )
156
- super (DataFrame , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
164
+ default_value = klass (default_value )
165
+ super (PandasType , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
157
166
158
167
def make_dynamic_default (self ):
159
168
if self .default_value is None :
@@ -162,41 +171,29 @@ def make_dynamic_default(self):
162
171
return self .default_value .copy ()
163
172
164
173
165
- class Series ( SciType ):
174
+ class DataFrame ( PandasType ):
166
175
167
- """A pandas series trait type."""
176
+ """A pandas dataframe trait type."""
168
177
169
- info_text = 'a pandas series '
178
+ info_text = 'a pandas dataframe '
170
179
171
- def validate (self , obj , value ):
172
- if value is None and not self .allow_none :
173
- self .error (obj , value )
174
- if value is None or value is Undefined :
175
- return super (Series , self ).validate (obj , value )
176
- try :
177
- value = pd .Series (value )
178
- except (ValueError , TypeError ) as e :
179
- raise TraitError (e )
180
- return super (Series , self ).validate (obj , value )
180
+ def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
181
+ if 'klass' not in kwargs and self .klass is None :
182
+ import pandas as pd
183
+ kwargs ['klass' ] = pd .DataFrame
184
+ super (DataFrame , self ).__init__ (
185
+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
181
186
182
- def set (self , obj , value ):
183
- new_value = self ._validate (obj , value )
184
- old_value = obj ._trait_values .get (self .name , self .default_value )
185
- obj ._trait_values [self .name ] = new_value
186
- if (old_value is None and new_value is not None ) or not old_value .equals (new_value ):
187
- obj ._notify_trait (self .name , old_value , new_value )
188
187
189
- def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
190
- import pandas as pd
191
- self .dtype = dtype
192
- if default_value is Undefined :
193
- default_value = pd .Series ()
194
- elif default_value is not None :
195
- default_value = pd .Series (default_value )
196
- super (Series , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
188
+ class Series (PandasType ):
197
189
198
- def make_dynamic_default (self ):
199
- if self .default_value is None :
200
- return self .default_value
201
- else :
202
- return self .default_value .copy ()
190
+ """A pandas series trait type."""
191
+
192
+ info_text = 'a pandas series'
193
+
194
+ def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
195
+ if 'klass' not in kwargs and self .klass is None :
196
+ import pandas as pd
197
+ kwargs ['klass' ] = pd .Series
198
+ super (Series , self ).__init__ (
199
+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
0 commit comments