Skip to content

Commit 96d3b76

Browse files
committed
Refactor commond pandas trait type (DRY)
1 parent 0b5d508 commit 96d3b76

File tree

1 file changed

+38
-41
lines changed

1 file changed

+38
-41
lines changed

traittypes/traittypes.py

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import warnings
23

34
from traitlets import TraitType, TraitError, Undefined
@@ -122,22 +123,24 @@ def make_dynamic_default(self):
122123
return np.copy(self.default_value)
123124

124125

125-
class DataFrame(SciType):
126+
class PandasType(SciType):
126127

127128
"""A pandas dataframe trait type."""
128129

129130
info_text = 'a pandas dataframe'
130131

132+
klass = None
133+
131134
def validate(self, obj, value):
132135
if value is None and not self.allow_none:
133136
self.error(obj, value)
134137
if value is None or value is Undefined:
135-
return super(DataFrame, self).validate(obj, value)
138+
return super(PandasType, self).validate(obj, value)
136139
try:
137-
value = pd.DataFrame(value)
140+
value = self.klass(value)
138141
except (ValueError, TypeError) as e:
139142
raise TraitError(e)
140-
return super(DataFrame, self).validate(obj, value)
143+
return super(PandasType, self).validate(obj, value)
141144

142145
def set(self, obj, value):
143146
new_value = self._validate(obj, value)
@@ -146,14 +149,20 @@ def set(self, obj, value):
146149
if (old_value is None and new_value is not None) or not old_value.equals(new_value):
147150
obj._notify_trait(self.name, old_value, new_value)
148151

149-
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
150-
import pandas as pd
152+
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, klass=None, **kwargs):
153+
if klass is None:
154+
klass = self.klass
155+
if (klass is not None) and inspect.isclass(klass):
156+
self.klass = klass
157+
else:
158+
raise TraitError('The klass attribute must be a class'
159+
' not: %r' % klass)
151160
self.dtype = dtype
152161
if default_value is Undefined:
153-
default_value = pd.DataFrame()
162+
default_value = klass()
154163
elif default_value is not None:
155-
default_value = pd.DataFrame(default_value)
156-
super(DataFrame, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
164+
default_value = klass(default_value)
165+
super(PandasType, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
157166

158167
def make_dynamic_default(self):
159168
if self.default_value is None:
@@ -162,41 +171,29 @@ def make_dynamic_default(self):
162171
return self.default_value.copy()
163172

164173

165-
class Series(SciType):
174+
class DataFrame(PandasType):
166175

167-
"""A pandas series trait type."""
176+
"""A pandas dataframe trait type."""
168177

169-
info_text = 'a pandas series'
178+
info_text = 'a pandas dataframe'
170179

171-
def validate(self, obj, value):
172-
if value is None and not self.allow_none:
173-
self.error(obj, value)
174-
if value is None or value is Undefined:
175-
return super(Series, self).validate(obj, value)
176-
try:
177-
value = pd.Series(value)
178-
except (ValueError, TypeError) as e:
179-
raise TraitError(e)
180-
return super(Series, self).validate(obj, value)
180+
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
181+
if 'klass' not in kwargs and self.klass is None:
182+
import pandas as pd
183+
kwargs['klass'] = pd.DataFrame
184+
super(DataFrame, self).__init__(
185+
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)
181186

182-
def set(self, obj, value):
183-
new_value = self._validate(obj, value)
184-
old_value = obj._trait_values.get(self.name, self.default_value)
185-
obj._trait_values[self.name] = new_value
186-
if (old_value is None and new_value is not None) or not old_value.equals(new_value):
187-
obj._notify_trait(self.name, old_value, new_value)
188187

189-
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
190-
import pandas as pd
191-
self.dtype = dtype
192-
if default_value is Undefined:
193-
default_value = pd.Series()
194-
elif default_value is not None:
195-
default_value = pd.Series(default_value)
196-
super(Series, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
188+
class Series(PandasType):
197189

198-
def make_dynamic_default(self):
199-
if self.default_value is None:
200-
return self.default_value
201-
else:
202-
return self.default_value.copy()
190+
"""A pandas series trait type."""
191+
192+
info_text = 'a pandas series'
193+
194+
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
195+
if 'klass' not in kwargs and self.klass is None:
196+
import pandas as pd
197+
kwargs['klass'] = pd.Series
198+
super(Series, self).__init__(
199+
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)

0 commit comments

Comments
 (0)