Skip to content

Commit 9ef8ff1

Browse files
committed
Adding a Series dataframe
1 parent 2668303 commit 9ef8ff1

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed

traittypes/tests/test_traittypes.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from unittest import TestCase
88
from traitlets import HasTraits, TraitError, observe
99
from traitlets.tests.test_traitlets import TraitTestBase
10-
from traittypes import Array, DataFrame
10+
from traittypes import Array, DataFrame, Series
1111
import numpy as np
1212
import pandas as pd
1313

@@ -138,3 +138,38 @@ class Foo(HasTraits):
138138
with self.assertRaises(TraitError):
139139
foo.bar = None
140140
foo.baz = None
141+
142+
143+
class TestSeries(TestCase):
144+
145+
def test_sereis_equal(self):
146+
notifications = []
147+
class Foo(HasTraits):
148+
bar = Series([1, 2])
149+
@observe('bar')
150+
def _(self, change):
151+
notifications.append(change)
152+
foo = Foo()
153+
foo.bar = [1, 2]
154+
self.assertFalse(len(notifications))
155+
foo.bar = [1, 1]
156+
self.assertTrue(len(notifications))
157+
158+
def test_initial_values(self):
159+
class Foo(HasTraits):
160+
a = Series()
161+
b = Series(None, allow_none=True)
162+
c = Series([])
163+
foo = Foo()
164+
self.assertTrue(foo.a.equals(pd.Series()))
165+
self.assertTrue(foo.b is None)
166+
self.assertTrue(foo.c.equals(pd.Series([])))
167+
168+
def test_allow_none(self):
169+
class Foo(HasTraits):
170+
bar = Series()
171+
baz = Series(allow_none=True)
172+
foo = Foo()
173+
with self.assertRaises(TraitError):
174+
foo.bar = None
175+
foo.baz = None

traittypes/traittypes.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,37 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
111111
default_value = pd.DataFrame(default_value)
112112
self.validators = []
113113
super(DataFrame, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
114+
115+
116+
class Series(SciType):
117+
118+
"""A pandas series trait type."""
119+
120+
info_text = 'a pandas series'
121+
122+
def validate(self, obj, value):
123+
if value is None and not self.allow_none:
124+
self.error(obj, value)
125+
try:
126+
value = pd.Series(value)
127+
for validator in self.validators:
128+
value = validator(self, value)
129+
return value
130+
except (ValueError, TypeError) as e:
131+
raise TraitError(e)
132+
133+
def set(self, obj, value):
134+
new_value = self._validate(obj, value)
135+
old_value = obj._trait_values.get(self.name, self.default_value)
136+
obj._trait_values[self.name] = new_value
137+
if (old_value is None and new_value is not None) or not old_value.equals(new_value):
138+
obj._notify_trait(self.name, old_value, new_value)
139+
140+
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
141+
self.dtype = dtype
142+
if default_value is Undefined:
143+
default_value = pd.Series()
144+
elif default_value is not None:
145+
default_value = pd.Series(default_value)
146+
self.validators = []
147+
super(Series, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)

0 commit comments

Comments
 (0)