Skip to content

Commit b6b1b67

Browse files
Merge pull request #11 from SylvainCorlay/DataFrame
Adding a DataFrame trait type
2 parents be9c606 + 9ef8ff1 commit b6b1b67

File tree

3 files changed

+177
-32
lines changed

3 files changed

+177
-32
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@
7878

7979
install_requires = setuptools_args['install_requires'] = [
8080
'traitlets>=4.2.2',
81-
'numpy'
81+
'numpy',
82+
'pandas'
8283
]
8384

8485
extras_require = setuptools_args['extras_require'] = {

traittypes/tests/test_traittypes.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from unittest import TestCase
88
from traitlets import HasTraits, TraitError, observe
99
from traitlets.tests.test_traitlets import TraitTestBase
10-
from traittypes import Array
10+
from traittypes import Array, DataFrame, Series
1111
import numpy as np
12+
import pandas as pd
1213

1314

1415
# Good / Bad value trait test cases
@@ -39,7 +40,7 @@ class TestArray(TestCase):
3940
def test_array_equal(self):
4041
notifications = []
4142
class Foo(HasTraits):
42-
bar = Array(default_value=[1, 2])
43+
bar = Array([1, 2])
4344
@observe('bar')
4445
def _(self, change):
4546
notifications.append(change)
@@ -103,3 +104,72 @@ class Foo(HasTraits):
103104
foo.bar = new_value
104105
self.assertTrue(np.array_equal(foo.bar, new_value))
105106

107+
108+
class TestDataFrame(TestCase):
109+
110+
def test_df_equal(self):
111+
notifications = []
112+
class Foo(HasTraits):
113+
bar = DataFrame([1, 2])
114+
@observe('bar')
115+
def _(self, change):
116+
notifications.append(change)
117+
foo = Foo()
118+
foo.bar = [1, 2]
119+
self.assertFalse(len(notifications))
120+
foo.bar = [1, 1]
121+
self.assertTrue(len(notifications))
122+
123+
def test_initial_values(self):
124+
class Foo(HasTraits):
125+
a = DataFrame()
126+
b = DataFrame(None, allow_none=True)
127+
c = DataFrame([])
128+
foo = Foo()
129+
self.assertTrue(foo.a.equals(pd.DataFrame()))
130+
self.assertTrue(foo.b is None)
131+
self.assertTrue(foo.c.equals(pd.DataFrame([])))
132+
133+
def test_allow_none(self):
134+
class Foo(HasTraits):
135+
bar = DataFrame()
136+
baz = DataFrame(allow_none=True)
137+
foo = Foo()
138+
with self.assertRaises(TraitError):
139+
foo.bar = None
140+
foo.baz = None
141+
142+
143+
class TestSeries(TestCase):
144+
145+
def test_sereis_equal(self):
146+
notifications = []
147+
class Foo(HasTraits):
148+
bar = Series([1, 2])
149+
@observe('bar')
150+
def _(self, change):
151+
notifications.append(change)
152+
foo = Foo()
153+
foo.bar = [1, 2]
154+
self.assertFalse(len(notifications))
155+
foo.bar = [1, 1]
156+
self.assertTrue(len(notifications))
157+
158+
def test_initial_values(self):
159+
class Foo(HasTraits):
160+
a = Series()
161+
b = Series(None, allow_none=True)
162+
c = Series([])
163+
foo = Foo()
164+
self.assertTrue(foo.a.equals(pd.Series()))
165+
self.assertTrue(foo.b is None)
166+
self.assertTrue(foo.c.equals(pd.Series([])))
167+
168+
def test_allow_none(self):
169+
class Foo(HasTraits):
170+
bar = Series()
171+
baz = Series(allow_none=True)
172+
foo = Foo()
173+
with self.assertRaises(TraitError):
174+
foo.bar = None
175+
foo.baz = None

traittypes/traittypes.py

Lines changed: 103 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,50 @@
11
from traitlets import TraitType, TraitError, Undefined
22
import numpy as np
3+
import pandas as pd
34

45

5-
class Array(TraitType):
6+
class SciType(TraitType):
7+
8+
"""A base traittype for numpy arrays, pandas dataframes and series."""
9+
10+
def valid(self, *validators):
11+
"""
12+
Register new trait validators
13+
14+
Validators are functions that take two arguments.
15+
- The trait instance
16+
- The proposed value
17+
18+
Validators return the (potentially modified) value, which is either
19+
assigned to the HasTraits attribute or input into the next validator.
20+
21+
They are evaluated in the order in which they are provided to the `valid`
22+
function.
23+
24+
Example
25+
-------
26+
27+
.. code-block:: python
28+
# Test with a shape constraint
29+
def shape(*dimensions):
30+
def validator(trait, value):
31+
if value.shape != dimensions:
32+
raise TraitError('Expected an of shape %s and got and array with shape %s' % (dimensions, value.shape))
33+
else:
34+
return value
35+
return validator
36+
37+
class Foo(HasTraits):
38+
bar = Array(np.identity(2)).valid(shape(2, 2))
39+
foo = Foo()
40+
41+
foo.bar = [1, 2] # Should raise a TraitError
42+
"""
43+
self.validators.extend(validators)
44+
return self
45+
46+
47+
class Array(SciType):
648

749
"""A numpy array trait type."""
850

@@ -36,38 +78,70 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
3678
self.validators = []
3779
super(Array, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
3880

39-
def valid(self, *validators):
40-
"""
41-
Register new trait validators
4281

43-
Validators are functions that take two arguments.
44-
- The trait instance
45-
- The proposed value
82+
class DataFrame(SciType):
4683

47-
Validators return the (potentially modified) value, which is either
48-
assigned to the HasTraits attribute or input into the next validator.
84+
"""A pandas dataframe trait type."""
4985

50-
They are evaluated in the order in which they are provided to the `valid`
51-
function.
86+
info_text = 'a pandas dataframe'
5287

53-
Example
54-
-------
88+
def validate(self, obj, value):
89+
if value is None and not self.allow_none:
90+
self.error(obj, value)
91+
try:
92+
value = pd.DataFrame(value)
93+
for validator in self.validators:
94+
value = validator(self, value)
95+
return value
96+
except (ValueError, TypeError) as e:
97+
raise TraitError(e)
5598

56-
.. code-block:: python
57-
# Test with a shape constraint
58-
def shape(*dimensions):
59-
def validator(trait, value):
60-
if value.shape != dimensions:
61-
raise TraitError('Expected an of shape %s and got and array with shape %s' % (dimensions, value.shape))
62-
else:
63-
return value
64-
return validator
99+
def set(self, obj, value):
100+
new_value = self._validate(obj, value)
101+
old_value = obj._trait_values.get(self.name, self.default_value)
102+
obj._trait_values[self.name] = new_value
103+
if (old_value is None and new_value is not None) or not old_value.equals(new_value):
104+
obj._notify_trait(self.name, old_value, new_value)
65105

66-
class Foo(HasTraits):
67-
bar = Array(np.identity(2)).valid(shape(2, 2))
68-
foo = Foo()
106+
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
107+
self.dtype = dtype
108+
if default_value is Undefined:
109+
default_value = pd.DataFrame()
110+
elif default_value is not None:
111+
default_value = pd.DataFrame(default_value)
112+
self.validators = []
113+
super(DataFrame, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
69114

70-
foo.bar = [1, 2] # Should raise a TraitError
71-
"""
72-
self.validators.extend(validators)
73-
return self
115+
116+
class Series(SciType):
117+
118+
"""A pandas series trait type."""
119+
120+
info_text = 'a pandas series'
121+
122+
def validate(self, obj, value):
123+
if value is None and not self.allow_none:
124+
self.error(obj, value)
125+
try:
126+
value = pd.Series(value)
127+
for validator in self.validators:
128+
value = validator(self, value)
129+
return value
130+
except (ValueError, TypeError) as e:
131+
raise TraitError(e)
132+
133+
def set(self, obj, value):
134+
new_value = self._validate(obj, value)
135+
old_value = obj._trait_values.get(self.name, self.default_value)
136+
obj._trait_values[self.name] = new_value
137+
if (old_value is None and new_value is not None) or not old_value.equals(new_value):
138+
obj._notify_trait(self.name, old_value, new_value)
139+
140+
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
141+
self.dtype = dtype
142+
if default_value is Undefined:
143+
default_value = pd.Series()
144+
elif default_value is not None:
145+
default_value = pd.Series(default_value)
146+
self.validators = []
147+
super(Series, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)

0 commit comments

Comments
 (0)