Skip to content

Commit d7814f8

Browse files
Merge pull request #3 from SylvainCorlay/overload_set
Overload set
2 parents 2b4b77f + 29cdee4 commit d7814f8

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

traittypes/tests/test_traittypes.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,49 @@
55
# Distributed under the terms of the Modified BSD License.
66

77
from unittest import TestCase
8-
from traitlets import HasTraits
8+
from traitlets import HasTraits, observe
99
from traitlets.tests.test_traitlets import TraitTestBase
1010
from traittypes import Array
1111
import numpy as np
1212

1313

14-
class ArrayTraitTestBase(TraitTestBase):
15-
"""A best testing class for numpy trait types.
16-
17-
:meth:`assertEqual` is overloaded to not use the `__eq__` operator.
18-
"""
19-
20-
def assertEqual(self, v1, v2):
21-
return np.testing.assert_array_equal(v1, v2)
14+
# Good / Bad value trait test cases
2215

2316

2417
class IntArrayTrait(HasTraits):
2518
value = Array().tag(dtype=np.int)
2619

2720

28-
class TestIntArray(ArrayTraitTestBase):
29-
"""Test d-type validation with a ``dtype=np.int``."""
21+
class TestIntArray(TraitTestBase):
22+
"""
23+
Test dtype validation with a ``dtype=np.int``
24+
"""
3025
obj = IntArrayTrait()
3126

3227
_good_values = [1, [1, 2, 3], [[1, 2, 3], [4, 5, 6]], np.array([1])]
3328
_bad_values = [[1, [0, 0]]]
29+
30+
31+
def assertEqual(self, v1, v2):
32+
return np.testing.assert_array_equal(v1, v2)
33+
34+
35+
# Other test cases
36+
37+
38+
class TestArray(TestCase):
39+
40+
def test_array_equal(self):
41+
notifications = []
42+
43+
class Foo(HasTraits):
44+
bar = Array(default_value=[1, 2])
45+
@observe('bar')
46+
def _(self, change):
47+
notifications.append(change)
48+
49+
foo = Foo()
50+
foo.bar = [1, 2]
51+
self.assertFalse(len(notifications))
52+
foo.bar = [1, 1]
53+
self.assertTrue(len(notifications))

traittypes/traittypes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,10 @@ def validate(self, obj, value):
1414
order=self.get_metadata('order'))
1515
except (ValueError, TypeError) as e:
1616
raise TraitError(e)
17+
18+
def set(self, obj, value):
19+
new_value = self._validate(obj, value)
20+
old_value = obj._trait_values.get(self.name, self.default_value)
21+
obj._trait_values[self.name] = new_value
22+
if not np.array_equal(old_value, new_value):
23+
obj._notify_trait(self.name, old_value, new_value)

0 commit comments

Comments
 (0)