Skip to content

Commit eff30f4

Browse files
Merge pull request #23 from vidartf/tweaks
Further tweak definitions
2 parents 96d857f + c5e1af2 commit eff30f4

File tree

3 files changed

+115
-56
lines changed

3 files changed

+115
-56
lines changed

traittypes/tests/test_traittypes.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# Distributed under the terms of the Modified BSD License.
66

77
from unittest import TestCase
8-
from traitlets import HasTraits, TraitError, observe
8+
from traitlets import HasTraits, TraitError, observe, Undefined
99
from traitlets.tests.test_traitlets import TraitTestBase
1010
from traittypes import Array, DataFrame, Series
1111
import numpy as np
@@ -56,11 +56,13 @@ class Foo(HasTraits):
5656
b = Array(dtype='int')
5757
c = Array(None, allow_none=True)
5858
d = Array([])
59+
e = Array(Undefined)
5960
foo = Foo()
6061
self.assertTrue(np.array_equal(foo.a, np.array(0)))
6162
self.assertTrue(np.array_equal(foo.b, np.array(0)))
6263
self.assertTrue(foo.c is None)
6364
self.assertTrue(np.array_equal(foo.d, []))
65+
self.assertTrue(foo.e is Undefined)
6466

6567
def test_allow_none(self):
6668
class Foo(HasTraits):
@@ -116,19 +118,21 @@ def _(self, change):
116118
notifications.append(change)
117119
foo = Foo()
118120
foo.bar = [1, 2]
119-
self.assertFalse(len(notifications))
121+
self.assertEqual(notifications, [])
120122
foo.bar = [1, 1]
121-
self.assertTrue(len(notifications))
123+
self.assertEqual(len(notifications), 1)
122124

123125
def test_initial_values(self):
124126
class Foo(HasTraits):
125127
a = DataFrame()
126128
b = DataFrame(None, allow_none=True)
127129
c = DataFrame([])
130+
d = DataFrame(Undefined)
128131
foo = Foo()
129132
self.assertTrue(foo.a.equals(pd.DataFrame()))
130133
self.assertTrue(foo.b is None)
131134
self.assertTrue(foo.c.equals(pd.DataFrame([])))
135+
self.assertTrue(foo.d is Undefined)
132136

133137
def test_allow_none(self):
134138
class Foo(HasTraits):
@@ -142,7 +146,7 @@ class Foo(HasTraits):
142146

143147
class TestSeries(TestCase):
144148

145-
def test_sereis_equal(self):
149+
def test_series_equal(self):
146150
notifications = []
147151
class Foo(HasTraits):
148152
bar = Series([1, 2])
@@ -151,19 +155,21 @@ def _(self, change):
151155
notifications.append(change)
152156
foo = Foo()
153157
foo.bar = [1, 2]
154-
self.assertFalse(len(notifications))
158+
self.assertEqual(notifications, [])
155159
foo.bar = [1, 1]
156-
self.assertTrue(len(notifications))
160+
self.assertEqual(len(notifications), 1)
157161

158162
def test_initial_values(self):
159163
class Foo(HasTraits):
160164
a = Series()
161165
b = Series(None, allow_none=True)
162166
c = Series([])
167+
d = Series(Undefined)
163168
foo = Foo()
164169
self.assertTrue(foo.a.equals(pd.Series()))
165170
self.assertTrue(foo.b is None)
166171
self.assertTrue(foo.c.equals(pd.Series([])))
172+
self.assertTrue(foo.d is Undefined)
167173

168174
def test_allow_none(self):
169175
class Foo(HasTraits):

traittypes/tests/test_validators.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
4+
# Copyright (c) Jupyter Development Team.
5+
# Distributed under the terms of the Modified BSD License.
6+
7+
import pytest
8+
9+
from traitlets import HasTraits, TraitError
10+
11+
from ..traittypes import SciType
12+
13+
14+
def test_coercion_validator():
15+
# Test with a squeeze coercion
16+
def truncate(trait, value):
17+
return value[:10]
18+
19+
class Foo(HasTraits):
20+
bar = SciType().valid(truncate)
21+
22+
foo = Foo(bar=list(range(20)))
23+
assert foo.bar == list(range(10))
24+
foo.bar = list(range(10, 40))
25+
assert foo.bar == list(range(10, 20))
26+
27+
28+
def test_validaton_error():
29+
# Test with a squeeze coercion
30+
def maxlen(trait, value):
31+
if len(value) > 10:
32+
raise ValueError('Too long sequence!')
33+
return value
34+
35+
class Foo(HasTraits):
36+
bar = SciType().valid(maxlen)
37+
38+
# Check that it works as expected:
39+
foo = Foo(bar=list(range(5)))
40+
assert foo.bar == list(range(5))
41+
# Check that it fails as expected:
42+
with pytest.raises(TraitError): # Should convert ValueError to TraitError
43+
foo.bar = list(range(10, 40))
44+
assert foo.bar == list(range(5))
45+
# Check that it can again be set correctly
46+
foo = Foo(bar=list(range(5, 10)))
47+
assert foo.bar == list(range(5, 10))

traittypes/traittypes.py

Lines changed: 56 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import inspect
12
import warnings
23

3-
from traitlets import TraitType, TraitError, Undefined
4+
from traitlets import TraitType, TraitError, Undefined, Sentinel
45

56
class _DelayedImportError(object):
67
def __init__(self, package_name):
@@ -20,6 +21,13 @@ def __getattribute__(self, name):
2021
pd = _DelayedImportError('pandas')
2122

2223

24+
Empty = Sentinel('Empty', 'traittypes',
25+
"""
26+
Used in traittypes to specify that the default value should
27+
be an empty dataset
28+
""")
29+
30+
2331
class SciType(TraitType):
2432

2533
"""A base trait type for numpy arrays, pandas dataframes and series."""
@@ -107,96 +115,94 @@ def set(self, obj, value):
107115
if not np.array_equal(old_value, new_value):
108116
obj._notify_trait(self.name, old_value, new_value)
109117

110-
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
118+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
111119
self.dtype = dtype
112-
if default_value is Undefined:
120+
if default_value is Empty:
113121
default_value = np.array(0, dtype=self.dtype)
114-
elif default_value is not None:
122+
elif default_value is not None and default_value is not Undefined:
115123
default_value = np.asarray(default_value, dtype=self.dtype)
116124
super(Array, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
117125

118126
def make_dynamic_default(self):
119-
if self.default_value is None:
127+
if self.default_value is None or self.default_value is Undefined:
120128
return self.default_value
121129
else:
122130
return np.copy(self.default_value)
123131

124132

125-
class DataFrame(SciType):
133+
class PandasType(SciType):
126134

127135
"""A pandas dataframe trait type."""
128136

129137
info_text = 'a pandas dataframe'
130138

139+
klass = None
140+
131141
def validate(self, obj, value):
132142
if value is None and not self.allow_none:
133143
self.error(obj, value)
134144
if value is None or value is Undefined:
135-
return super(DataFrame, self).validate(obj, value)
145+
return super(PandasType, self).validate(obj, value)
136146
try:
137-
value = pd.DataFrame(value)
147+
value = self.klass(value)
138148
except (ValueError, TypeError) as e:
139149
raise TraitError(e)
140-
return super(DataFrame, self).validate(obj, value)
150+
return super(PandasType, self).validate(obj, value)
141151

142152
def set(self, obj, value):
143153
new_value = self._validate(obj, value)
144154
old_value = obj._trait_values.get(self.name, self.default_value)
145155
obj._trait_values[self.name] = new_value
146-
if (old_value is None and new_value is not None) or not old_value.equals(new_value):
156+
if ((old_value is None and new_value is not None) or
157+
(old_value is Undefined and new_value is not Undefined) or
158+
not old_value.equals(new_value)):
147159
obj._notify_trait(self.name, old_value, new_value)
148160

149-
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
150-
import pandas as pd
161+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, klass=None, **kwargs):
162+
if klass is None:
163+
klass = self.klass
164+
if (klass is not None) and inspect.isclass(klass):
165+
self.klass = klass
166+
else:
167+
raise TraitError('The klass attribute must be a class'
168+
' not: %r' % klass)
151169
self.dtype = dtype
152-
if default_value is Undefined:
153-
default_value = pd.DataFrame()
154-
elif default_value is not None:
155-
default_value = pd.DataFrame(default_value)
156-
super(DataFrame, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
170+
if default_value is Empty:
171+
default_value = klass()
172+
elif default_value is not None and default_value is not Undefined:
173+
default_value = klass(default_value)
174+
super(PandasType, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
157175

158176
def make_dynamic_default(self):
159-
if self.default_value is None:
177+
if self.default_value is None or self.default_value is Undefined:
160178
return self.default_value
161179
else:
162180
return self.default_value.copy()
163181

164182

165-
class Series(SciType):
183+
class DataFrame(PandasType):
166184

167-
"""A pandas series trait type."""
185+
"""A pandas dataframe trait type."""
168186

169-
info_text = 'a pandas series'
187+
info_text = 'a pandas dataframe'
170188

171-
def validate(self, obj, value):
172-
if value is None and not self.allow_none:
173-
self.error(obj, value)
174-
if value is None or value is Undefined:
175-
return super(Series, self).validate(obj, value)
176-
try:
177-
value = pd.Series(value)
178-
except (ValueError, TypeError) as e:
179-
raise TraitError(e)
180-
return super(Series, self).validate(obj, value)
189+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
190+
if 'klass' not in kwargs and self.klass is None:
191+
import pandas as pd
192+
kwargs['klass'] = pd.DataFrame
193+
super(DataFrame, self).__init__(
194+
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)
181195

182-
def set(self, obj, value):
183-
new_value = self._validate(obj, value)
184-
old_value = obj._trait_values.get(self.name, self.default_value)
185-
obj._trait_values[self.name] = new_value
186-
if (old_value is None and new_value is not None) or not old_value.equals(new_value):
187-
obj._notify_trait(self.name, old_value, new_value)
188196

189-
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
190-
import pandas as pd
191-
self.dtype = dtype
192-
if default_value is Undefined:
193-
default_value = pd.Series()
194-
elif default_value is not None:
195-
default_value = pd.Series(default_value)
196-
super(Series, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
197+
class Series(PandasType):
197198

198-
def make_dynamic_default(self):
199-
if self.default_value is None:
200-
return self.default_value
201-
else:
202-
return self.default_value.copy()
199+
"""A pandas series trait type."""
200+
201+
info_text = 'a pandas series'
202+
203+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
204+
if 'klass' not in kwargs and self.klass is None:
205+
import pandas as pd
206+
kwargs['klass'] = pd.Series
207+
super(Series, self).__init__(
208+
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)

0 commit comments

Comments
 (0)