Skip to content

Commit 5aa0fd6

Browse files
committed
Add custom validation ot Array trait type
1 parent 85cdccf commit 5aa0fd6

File tree

3 files changed

+125
-9
lines changed

3 files changed

+125
-9
lines changed

README.md

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,63 @@ Trait types for NumPy, SciPy and friends
66

77
## Goals
88

9-
Provide a reference implementation of trait types for common data structures used in the scipy stack such as
9+
Provide a reference implementation of trait types for common data structures
10+
used in the scipy stack such as
1011
- numpy arrays
1112
- pandas / xray data structures
1213

13-
which are out of the scope of the main [traitlets](https://github.com/ipython/traitlets) project but are a common requirement to build applications with traitlets in combination with the scipy stack.
14+
which are out of the scope of the main [traitlets](https://github.com/ipython/traitlets)
15+
project but are a common requirement to build applications with traitlets in
16+
combination with the scipy stack.
1417

15-
Another goal is to create adequate serialization and deserialization routines for these trait types to be used with the [ipywidgets](https://github.com/ipython/ipywidgets) project (`to_json` and `from_json`). These could also return a list of binary buffers as allowed by the current message protocol.
18+
Another goal is to create adequate serialization and deserialization routines
19+
for these trait types to be used with the [ipywidgets](https://github.com/ipython/ipywidgets)
20+
project (`to_json` and `from_json`). These could also return a list of binary
21+
buffers as allowed by the current messaging protocol.
1622

1723
## Installation
1824

19-
For a local installation, make sure you have
20-
[pip installed](https://pip.readthedocs.org/en/stable/installing/) and run:
25+
26+
Using `pip`:
27+
28+
Make sure you have [pip installed](https://pip.readthedocs.org/en/stable/installing/) and run:
2129

2230
```
2331
pip install traittypes
2432
```
33+
34+
Using `conda`:
35+
36+
```
37+
conda install -c conda-forge traittypes
38+
```
39+
40+
## Usage
41+
42+
The `Array` trait type provide an implementation of a trait type for the numpy
43+
array.
44+
- `Array` overrides some methods from `TraiType` that are generally not
45+
overloaded in order to work around some limitations with numpy array
46+
comparison.
47+
- `Array` provides an API for adding custom validators to constained proposed
48+
values for the attribute.
49+
50+
```python
51+
from traitlets import HasTraits, TraitError
52+
from traittypes import Array
53+
54+
def shape(*dimensions):
55+
def validator(trait, value):
56+
if value.shape != dimensions:
57+
raise TraitError('Expected an of shape %s and got and array with shape %s' % (dimensions, value.shape))
58+
else:
59+
return value
60+
return validator
61+
62+
class Foo(HasTraits):
63+
bar = Array(np.identity(2)).valid(shape(2, 2))
64+
foo = Foo()
65+
66+
foo.bar = [1, 2] # Should raise a TraitError
67+
```
68+

traittypes/tests/test_traittypes.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class TestIntArray(TraitTestBase):
2727
_good_values = [1, [1, 2, 3], [[1, 2, 3], [4, 5, 6]], np.array([1])]
2828
_bad_values = [[1, [0, 0]]]
2929

30-
3130
def assertEqual(self, v1, v2):
3231
return np.testing.assert_array_equal(v1, v2)
3332

@@ -70,3 +69,37 @@ class Foo(HasTraits):
7069
with self.assertRaises(TraitError):
7170
foo.bar = None
7271
foo.baz = None
72+
73+
def test_custom_validators(self):
74+
# Test with a squeeze coercion
75+
def squeeze(trait, value):
76+
if 1 in value.shape:
77+
value = np.squeeze(value)
78+
return value
79+
80+
class Foo(HasTraits):
81+
bar = Array().valid(squeeze)
82+
83+
foo = Foo(bar=[[1], [2]])
84+
self.assertTrue(np.array_equal(foo.bar, [1, 2]))
85+
foo.bar = [[1], [2], [3]]
86+
self.assertTrue(np.array_equal(foo.bar, [1, 2, 3]))
87+
88+
# Test with a shape constraint
89+
def shape(*dimensions):
90+
def validator(trait, value):
91+
if value.shape != dimensions:
92+
raise TraitError('Expected an of shape %s and got and array with shape %s' % (dimensions, value.shape))
93+
else:
94+
return value
95+
return validator
96+
97+
class Foo(HasTraits):
98+
bar = Array(np.identity(2)).valid(shape(2, 2))
99+
foo = Foo()
100+
with self.assertRaises(TraitError):
101+
foo.bar = [1]
102+
new_value = [[0, 1], [1, 0]]
103+
foo.bar = new_value
104+
self.assertTrue(np.array_equal(foo.bar, new_value))
105+

traittypes/traittypes.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ def validate(self, obj, value):
1313
if value is None and not self.allow_none:
1414
self.error(obj, value)
1515
try:
16-
return np.asarray(value, dtype=self.dtype)
16+
value = np.asarray(value, dtype=self.dtype)
17+
for validator in self.validators:
18+
value = validator(self, value)
19+
return value
1720
except (ValueError, TypeError) as e:
1821
raise TraitError(e)
1922

@@ -24,11 +27,47 @@ def set(self, obj, value):
2427
if not np.array_equal(old_value, new_value):
2528
obj._notify_trait(self.name, old_value, new_value)
2629

27-
def __init__(self, default_value=Undefined, allow_none=False,
28-
dtype=None, **kwargs):
30+
def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs):
2931
self.dtype = dtype
3032
if default_value is Undefined:
3133
default_value = np.array(0, dtype=self.dtype)
3234
elif default_value is not None:
3335
default_value = np.asarray(default_value, dtype=self.dtype)
36+
self.validators = []
3437
super(Array, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
38+
39+
def valid(self, *validators):
40+
"""
41+
Register new trait validators
42+
43+
Validators are functions that take two arguments.
44+
- The trait instance
45+
- The proposed value
46+
47+
Validators return the (potentially modified) value, which is either
48+
assigned to the HasTraits attribute or input into the next validator.
49+
50+
They are evaluated in the order in which they are provided to the `valid`
51+
function.
52+
53+
Example
54+
-------
55+
56+
.. code-block:: python
57+
# Test with a shape constraint
58+
def shape(*dimensions):
59+
def validator(trait, value):
60+
if value.shape != dimensions:
61+
raise TraitError('Expected an of shape %s and got and array with shape %s' % (dimensions, value.shape))
62+
else:
63+
return value
64+
return validator
65+
66+
class Foo(HasTraits):
67+
bar = Array(np.identity(2)).valid(shape(2, 2))
68+
foo = Foo()
69+
70+
foo.bar = [1, 2] # Should raise a TraitError
71+
"""
72+
self.validators.extend(validators)
73+
return self

0 commit comments

Comments
 (0)