|
1 | 1 | from traitlets import TraitType, TraitError, Undefined
|
2 | 2 | import numpy as np
|
| 3 | +import pandas as pd |
3 | 4 |
|
4 | 5 |
|
5 |
| -class Array(TraitType): |
| 6 | +class SciType(TraitType): |
| 7 | + |
| 8 | + """A base traittype for numpy arrays, pandas dataframes and series.""" |
| 9 | + |
| 10 | + def valid(self, *validators): |
| 11 | + """ |
| 12 | + Register new trait validators |
| 13 | +
|
| 14 | + Validators are functions that take two arguments. |
| 15 | + - The trait instance |
| 16 | + - The proposed value |
| 17 | +
|
| 18 | + Validators return the (potentially modified) value, which is either |
| 19 | + assigned to the HasTraits attribute or input into the next validator. |
| 20 | +
|
| 21 | + They are evaluated in the order in which they are provided to the `valid` |
| 22 | + function. |
| 23 | +
|
| 24 | + Example |
| 25 | + ------- |
| 26 | +
|
| 27 | + .. code-block:: python |
| 28 | + # Test with a shape constraint |
| 29 | + def shape(*dimensions): |
| 30 | + def validator(trait, value): |
| 31 | + if value.shape != dimensions: |
| 32 | + raise TraitError('Expected an of shape %s and got and array with shape %s' % (dimensions, value.shape)) |
| 33 | + else: |
| 34 | + return value |
| 35 | + return validator |
| 36 | +
|
| 37 | + class Foo(HasTraits): |
| 38 | + bar = Array(np.identity(2)).valid(shape(2, 2)) |
| 39 | + foo = Foo() |
| 40 | +
|
| 41 | + foo.bar = [1, 2] # Should raise a TraitError |
| 42 | + """ |
| 43 | + self.validators.extend(validators) |
| 44 | + return self |
| 45 | + |
| 46 | + |
| 47 | +class Array(SciType): |
6 | 48 |
|
7 | 49 | """A numpy array trait type."""
|
8 | 50 |
|
@@ -36,38 +78,70 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
|
36 | 78 | self.validators = []
|
37 | 79 | super(Array, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
|
38 | 80 |
|
39 |
| - def valid(self, *validators): |
40 |
| - """ |
41 |
| - Register new trait validators |
42 | 81 |
|
43 |
| - Validators are functions that take two arguments. |
44 |
| - - The trait instance |
45 |
| - - The proposed value |
| 82 | +class DataFrame(SciType): |
46 | 83 |
|
47 |
| - Validators return the (potentially modified) value, which is either |
48 |
| - assigned to the HasTraits attribute or input into the next validator. |
| 84 | + """A pandas dataframe trait type.""" |
49 | 85 |
|
50 |
| - They are evaluated in the order in which they are provided to the `valid` |
51 |
| - function. |
| 86 | + info_text = 'a pandas dataframe' |
52 | 87 |
|
53 |
| - Example |
54 |
| - ------- |
| 88 | + def validate(self, obj, value): |
| 89 | + if value is None and not self.allow_none: |
| 90 | + self.error(obj, value) |
| 91 | + try: |
| 92 | + value = pd.DataFrame(value) |
| 93 | + for validator in self.validators: |
| 94 | + value = validator(self, value) |
| 95 | + return value |
| 96 | + except (ValueError, TypeError) as e: |
| 97 | + raise TraitError(e) |
55 | 98 |
|
56 |
| - .. code-block:: python |
57 |
| - # Test with a shape constraint |
58 |
| - def shape(*dimensions): |
59 |
| - def validator(trait, value): |
60 |
| - if value.shape != dimensions: |
61 |
| - raise TraitError('Expected an of shape %s and got and array with shape %s' % (dimensions, value.shape)) |
62 |
| - else: |
63 |
| - return value |
64 |
| - return validator |
| 99 | + def set(self, obj, value): |
| 100 | + new_value = self._validate(obj, value) |
| 101 | + old_value = obj._trait_values.get(self.name, self.default_value) |
| 102 | + obj._trait_values[self.name] = new_value |
| 103 | + if (old_value is None and new_value is not None) or not old_value.equals(new_value): |
| 104 | + obj._notify_trait(self.name, old_value, new_value) |
65 | 105 |
|
66 |
| - class Foo(HasTraits): |
67 |
| - bar = Array(np.identity(2)).valid(shape(2, 2)) |
68 |
| - foo = Foo() |
| 106 | + def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs): |
| 107 | + self.dtype = dtype |
| 108 | + if default_value is Undefined: |
| 109 | + default_value = pd.DataFrame() |
| 110 | + elif default_value is not None: |
| 111 | + default_value = pd.DataFrame(default_value) |
| 112 | + self.validators = [] |
| 113 | + super(DataFrame, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs) |
69 | 114 |
|
70 |
| - foo.bar = [1, 2] # Should raise a TraitError |
71 |
| - """ |
72 |
| - self.validators.extend(validators) |
73 |
| - return self |
| 115 | + |
| 116 | +class Series(SciType): |
| 117 | + |
| 118 | + """A pandas series trait type.""" |
| 119 | + |
| 120 | + info_text = 'a pandas series' |
| 121 | + |
| 122 | + def validate(self, obj, value): |
| 123 | + if value is None and not self.allow_none: |
| 124 | + self.error(obj, value) |
| 125 | + try: |
| 126 | + value = pd.Series(value) |
| 127 | + for validator in self.validators: |
| 128 | + value = validator(self, value) |
| 129 | + return value |
| 130 | + except (ValueError, TypeError) as e: |
| 131 | + raise TraitError(e) |
| 132 | + |
| 133 | + def set(self, obj, value): |
| 134 | + new_value = self._validate(obj, value) |
| 135 | + old_value = obj._trait_values.get(self.name, self.default_value) |
| 136 | + obj._trait_values[self.name] = new_value |
| 137 | + if (old_value is None and new_value is not None) or not old_value.equals(new_value): |
| 138 | + obj._notify_trait(self.name, old_value, new_value) |
| 139 | + |
| 140 | + def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwargs): |
| 141 | + self.dtype = dtype |
| 142 | + if default_value is Undefined: |
| 143 | + default_value = pd.Series() |
| 144 | + elif default_value is not None: |
| 145 | + default_value = pd.Series(default_value) |
| 146 | + self.validators = [] |
| 147 | + super(Series, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs) |
0 commit comments