Skip to content

Commit c5e1af2

Browse files
committed
Allow Undefined as default
Introduces a new sentinel for default behavior of initializing an empty data structure.
1 parent 96d3b76 commit c5e1af2

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

traittypes/tests/test_traittypes.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# Distributed under the terms of the Modified BSD License.
66

77
from unittest import TestCase
8-
from traitlets import HasTraits, TraitError, observe
8+
from traitlets import HasTraits, TraitError, observe, Undefined
99
from traitlets.tests.test_traitlets import TraitTestBase
1010
from traittypes import Array, DataFrame, Series
1111
import numpy as np
@@ -56,11 +56,13 @@ class Foo(HasTraits):
5656
b = Array(dtype='int')
5757
c = Array(None, allow_none=True)
5858
d = Array([])
59+
e = Array(Undefined)
5960
foo = Foo()
6061
self.assertTrue(np.array_equal(foo.a, np.array(0)))
6162
self.assertTrue(np.array_equal(foo.b, np.array(0)))
6263
self.assertTrue(foo.c is None)
6364
self.assertTrue(np.array_equal(foo.d, []))
65+
self.assertTrue(foo.e is Undefined)
6466

6567
def test_allow_none(self):
6668
class Foo(HasTraits):
@@ -116,19 +118,21 @@ def _(self, change):
116118
notifications.append(change)
117119
foo = Foo()
118120
foo.bar = [1, 2]
119-
self.assertFalse(len(notifications))
121+
self.assertEqual(notifications, [])
120122
foo.bar = [1, 1]
121-
self.assertTrue(len(notifications))
123+
self.assertEqual(len(notifications), 1)
122124

123125
def test_initial_values(self):
124126
class Foo(HasTraits):
125127
a = DataFrame()
126128
b = DataFrame(None, allow_none=True)
127129
c = DataFrame([])
130+
d = DataFrame(Undefined)
128131
foo = Foo()
129132
self.assertTrue(foo.a.equals(pd.DataFrame()))
130133
self.assertTrue(foo.b is None)
131134
self.assertTrue(foo.c.equals(pd.DataFrame([])))
135+
self.assertTrue(foo.d is Undefined)
132136

133137
def test_allow_none(self):
134138
class Foo(HasTraits):
@@ -142,7 +146,7 @@ class Foo(HasTraits):
142146

143147
class TestSeries(TestCase):
144148

145-
def test_sereis_equal(self):
149+
def test_series_equal(self):
146150
notifications = []
147151
class Foo(HasTraits):
148152
bar = Series([1, 2])
@@ -151,19 +155,21 @@ def _(self, change):
151155
notifications.append(change)
152156
foo = Foo()
153157
foo.bar = [1, 2]
154-
self.assertFalse(len(notifications))
158+
self.assertEqual(notifications, [])
155159
foo.bar = [1, 1]
156-
self.assertTrue(len(notifications))
160+
self.assertEqual(len(notifications), 1)
157161

158162
def test_initial_values(self):
159163
class Foo(HasTraits):
160164
a = Series()
161165
b = Series(None, allow_none=True)
162166
c = Series([])
167+
d = Series(Undefined)
163168
foo = Foo()
164169
self.assertTrue(foo.a.equals(pd.Series()))
165170
self.assertTrue(foo.b is None)
166171
self.assertTrue(foo.c.equals(pd.Series([])))
172+
self.assertTrue(foo.d is Undefined)
167173

168174
def test_allow_none(self):
169175
class Foo(HasTraits):

traittypes/traittypes.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import warnings
33

4-
from traitlets import TraitType, TraitError, Undefined
4+
from traitlets import TraitType, TraitError, Undefined, Sentinel
55

66
class _DelayedImportError(object):
77
def __init__(self, package_name):
@@ -21,6 +21,13 @@ def __getattribute__(self, name):
2121
pd = _DelayedImportError('pandas')
2222

2323

24+
Empty = Sentinel('Empty', 'traittypes',
25+
"""
26+
Used in traittypes to specify that the default value should
27+
be an empty dataset
28+
""")
29+
30+
2431
class SciType(TraitType):
2532

2633
"""A base trait type for numpy arrays, pandas dataframes and series."""
@@ -108,16 +115,16 @@ def set(self, obj, value):
108115
if not np.array_equal(old_value, new_value):
109116
obj._notify_trait(self.name, old_value, new_value)
110117

111-
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
118+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
112119
self.dtype = dtype
113-
if default_value is Undefined:
120+
if default_value is Empty:
114121
default_value = np.array(0, dtype=self.dtype)
115-
elif default_value is not None:
122+
elif default_value is not None and default_value is not Undefined:
116123
default_value = np.asarray(default_value, dtype=self.dtype)
117124
super(Array, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
118125

119126
def make_dynamic_default(self):
120-
if self.default_value is None:
127+
if self.default_value is None or self.default_value is Undefined:
121128
return self.default_value
122129
else:
123130
return np.copy(self.default_value)
@@ -146,10 +153,12 @@ def set(self, obj, value):
146153
new_value = self._validate(obj, value)
147154
old_value = obj._trait_values.get(self.name, self.default_value)
148155
obj._trait_values[self.name] = new_value
149-
if (old_value is None and new_value is not None) or not old_value.equals(new_value):
156+
if ((old_value is None and new_value is not None) or
157+
(old_value is Undefined and new_value is not Undefined) or
158+
not old_value.equals(new_value)):
150159
obj._notify_trait(self.name, old_value, new_value)
151160

152-
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, klass=None, **kwargs):
161+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, klass=None, **kwargs):
153162
if klass is None:
154163
klass = self.klass
155164
if (klass is not None) and inspect.isclass(klass):
@@ -158,14 +167,14 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, klass=
158167
raise TraitError('The klass attribute must be a class'
159168
' not: %r' % klass)
160169
self.dtype = dtype
161-
if default_value is Undefined:
170+
if default_value is Empty:
162171
default_value = klass()
163-
elif default_value is not None:
172+
elif default_value is not None and default_value is not Undefined:
164173
default_value = klass(default_value)
165174
super(PandasType, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
166175

167176
def make_dynamic_default(self):
168-
if self.default_value is None:
177+
if self.default_value is None or self.default_value is Undefined:
169178
return self.default_value
170179
else:
171180
return self.default_value.copy()
@@ -177,7 +186,7 @@ class DataFrame(PandasType):
177186

178187
info_text = 'a pandas dataframe'
179188

180-
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
189+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
181190
if 'klass' not in kwargs and self.klass is None:
182191
import pandas as pd
183192
kwargs['klass'] = pd.DataFrame
@@ -191,7 +200,7 @@ class Series(PandasType):
191200

192201
info_text = 'a pandas series'
193202

194-
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
203+
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
195204
if 'klass' not in kwargs and self.klass is None:
196205
import pandas as pd
197206
kwargs['klass'] = pd.Series

0 commit comments

Comments
 (0)