Skip to content

Commit 67bbc70

Browse files
committed
Better handling of default values
1 parent d7814f8 commit 67bbc70

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

traittypes/tests/test_traittypes.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# Distributed under the terms of the Modified BSD License.
66

77
from unittest import TestCase
8-
from traitlets import HasTraits, observe
8+
from traitlets import HasTraits, TraitError, observe
99
from traitlets.tests.test_traitlets import TraitTestBase
1010
from traittypes import Array
1111
import numpy as np
@@ -39,15 +39,34 @@ class TestArray(TestCase):
3939

4040
def test_array_equal(self):
4141
notifications = []
42-
4342
class Foo(HasTraits):
4443
bar = Array(default_value=[1, 2])
4544
@observe('bar')
4645
def _(self, change):
4746
notifications.append(change)
48-
4947
foo = Foo()
5048
foo.bar = [1, 2]
5149
self.assertFalse(len(notifications))
5250
foo.bar = [1, 1]
5351
self.assertTrue(len(notifications))
52+
53+
def test_initial_values(self):
54+
class Foo(HasTraits):
55+
a = Array()
56+
b = Array(dtype='int')
57+
c = Array(None, allow_none=True)
58+
d = Array([])
59+
foo = Foo()
60+
self.assertTrue(np.array_equal(foo.a, np.array(0)))
61+
self.assertTrue(np.array_equal(foo.b, np.array(0)))
62+
self.assertTrue(foo.c is None)
63+
self.assertTrue(np.array_equal(foo.d, []))
64+
65+
def test_allow_none(self):
66+
class Foo(HasTraits):
67+
bar = Array()
68+
baz = Array(allow_none=True)
69+
foo = Foo()
70+
with self.assertRaises(TraitError):
71+
foo.bar = None
72+
foo.baz = None

traittypes/traittypes.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from traitlets import TraitType, TraitError
1+
from traitlets import TraitType, TraitError, Undefined
22
import numpy as np
33

44

@@ -7,11 +7,13 @@ class Array(TraitType):
77
"""A numpy array trait type."""
88

99
info_text = 'a numpy array'
10+
dtype = None
1011

1112
def validate(self, obj, value):
13+
if value is None and not self.allow_none:
14+
self.error(obj, value)
1215
try:
13-
return np.asarray(value, dtype=self.get_metadata('dtype'),
14-
order=self.get_metadata('order'))
16+
return np.asarray(value, dtype=self.dtype)
1517
except (ValueError, TypeError) as e:
1618
raise TraitError(e)
1719

@@ -21,3 +23,12 @@ def set(self, obj, value):
2123
obj._trait_values[self.name] = new_value
2224
if not np.array_equal(old_value, new_value):
2325
obj._notify_trait(self.name, old_value, new_value)
26+
27+
def __init__(self, default_value=Undefined, allow_none=False,
28+
dtype=None, **kwargs):
29+
self.dtype = dtype
30+
if default_value is Undefined:
31+
default_value = np.array(0, dtype=self.dtype)
32+
elif default_value is not None:
33+
default_value = np.asarray(default_value, dtype=self.dtype)
34+
super(Array, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)

0 commit comments

Comments
 (0)