Skip to content

Commit 5a1a8f9

Browse files
committed
Tweak validation code
1 parent 95d7369 commit 5a1a8f9

File tree

1 file changed

+32
-13
lines changed

1 file changed

+32
-13
lines changed

traittypes/traittypes.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
from traitlets import TraitType, TraitError, Undefined
24

35
class _DelayedImportError(object):
@@ -22,6 +24,10 @@ class SciType(TraitType):
2224

2325
"""A base trait type for numpy arrays, pandas dataframes and series."""
2426

27+
def __init__(self, **kwargs):
28+
super(SciType, self).__init__(**kwargs)
29+
self.validators = []
30+
2531
def valid(self, *validators):
2632
"""
2733
Register new trait validators
@@ -59,6 +65,15 @@ class Foo(HasTraits):
5965
self.validators.extend(validators)
6066
return self
6167

68+
def validate(self, obj, value):
69+
"""Validate the value against registered validators."""
70+
try:
71+
for validator in self.validators:
72+
value = validator(self, value)
73+
return value
74+
except (ValueError, TypeError) as e:
75+
raise TraitError(e)
76+
6277

6378
class Array(SciType):
6479

@@ -70,13 +85,20 @@ class Array(SciType):
7085
def validate(self, obj, value):
7186
if value is None and not self.allow_none:
7287
self.error(obj, value)
88+
if value is None or value is Undefined:
89+
return super(Array, self).validate(obj, value)
7390
try:
74-
value = np.asarray(value, dtype=self.dtype)
75-
for validator in self.validators:
76-
value = validator(self, value)
77-
return value
91+
r = np.asarray(value, dtype=self.dtype)
92+
if isinstance(value, np.ndarray) and r is not value:
93+
warnings.warn(
94+
'Given trait value dtype "%s" does not match required type "%s". '
95+
'A coerced copy has been created.' % (
96+
np.dtype(value.dtype).name,
97+
np.dtype(self.dtype).name))
98+
value = r
7899
except (ValueError, TypeError) as e:
79100
raise TraitError(e)
101+
return super(Array, self).validate(obj, value)
80102

81103
def set(self, obj, value):
82104
new_value = self._validate(obj, value)
@@ -91,7 +113,6 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
91113
default_value = np.array(0, dtype=self.dtype)
92114
elif default_value is not None:
93115
default_value = np.asarray(default_value, dtype=self.dtype)
94-
self.validators = []
95116
super(Array, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
96117

97118
def make_dynamic_default(self):
@@ -110,13 +131,13 @@ class DataFrame(SciType):
110131
def validate(self, obj, value):
111132
if value is None and not self.allow_none:
112133
self.error(obj, value)
134+
if value is None or value is Undefined:
135+
return super(DataFrame, self).validate(obj, value)
113136
try:
114137
value = pd.DataFrame(value)
115-
for validator in self.validators:
116-
value = validator(self, value)
117-
return value
118138
except (ValueError, TypeError) as e:
119139
raise TraitError(e)
140+
return super(DataFrame, self).validate(obj, value)
120141

121142
def set(self, obj, value):
122143
new_value = self._validate(obj, value)
@@ -132,7 +153,6 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
132153
default_value = pd.DataFrame()
133154
elif default_value is not None:
134155
default_value = pd.DataFrame(default_value)
135-
self.validators = []
136156
super(DataFrame, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
137157

138158
def make_dynamic_default(self):
@@ -151,13 +171,13 @@ class Series(SciType):
151171
def validate(self, obj, value):
152172
if value is None and not self.allow_none:
153173
self.error(obj, value)
174+
if value is None or value is Undefined:
175+
return super(Series, self).validate(obj, value)
154176
try:
155177
value = pd.Series(value)
156-
for validator in self.validators:
157-
value = validator(self, value)
158-
return value
159178
except (ValueError, TypeError) as e:
160179
raise TraitError(e)
180+
return super(Series, self).validate(obj, value)
161181

162182
def set(self, obj, value):
163183
new_value = self._validate(obj, value)
@@ -173,7 +193,6 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
173193
default_value = pd.Series()
174194
elif default_value is not None:
175195
default_value = pd.Series(default_value)
176-
self.validators = []
177196
super(Series, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
178197

179198
def make_dynamic_default(self):

0 commit comments

Comments
 (0)