1
+ import warnings
2
+
1
3
from traitlets import TraitType , TraitError , Undefined
2
4
3
5
class _DelayedImportError (object ):
@@ -22,6 +24,10 @@ class SciType(TraitType):
22
24
23
25
"""A base trait type for numpy arrays, pandas dataframes and series."""
24
26
27
+ def __init__ (self , ** kwargs ):
28
+ super (SciType , self ).__init__ (** kwargs )
29
+ self .validators = []
30
+
25
31
def valid (self , * validators ):
26
32
"""
27
33
Register new trait validators
@@ -59,6 +65,15 @@ class Foo(HasTraits):
59
65
self .validators .extend (validators )
60
66
return self
61
67
68
+ def validate (self , obj , value ):
69
+ """Validate the value against registered validators."""
70
+ try :
71
+ for validator in self .validators :
72
+ value = validator (self , value )
73
+ return value
74
+ except (ValueError , TypeError ) as e :
75
+ raise TraitError (e )
76
+
62
77
63
78
class Array (SciType ):
64
79
@@ -70,13 +85,20 @@ class Array(SciType):
70
85
def validate (self , obj , value ):
71
86
if value is None and not self .allow_none :
72
87
self .error (obj , value )
88
+ if value is None or value is Undefined :
89
+ return super (Array , self ).validate (obj , value )
73
90
try :
74
- value = np .asarray (value , dtype = self .dtype )
75
- for validator in self .validators :
76
- value = validator (self , value )
77
- return value
91
+ r = np .asarray (value , dtype = self .dtype )
92
+ if isinstance (value , np .ndarray ) and r is not value :
93
+ warnings .warn (
94
+ 'Given trait value dtype "%s" does not match required type "%s". '
95
+ 'A coerced copy has been created.' % (
96
+ np .dtype (value .dtype ).name ,
97
+ np .dtype (self .dtype ).name ))
98
+ value = r
78
99
except (ValueError , TypeError ) as e :
79
100
raise TraitError (e )
101
+ return super (Array , self ).validate (obj , value )
80
102
81
103
def set (self , obj , value ):
82
104
new_value = self ._validate (obj , value )
@@ -91,7 +113,6 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
91
113
default_value = np .array (0 , dtype = self .dtype )
92
114
elif default_value is not None :
93
115
default_value = np .asarray (default_value , dtype = self .dtype )
94
- self .validators = []
95
116
super (Array , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
96
117
97
118
def make_dynamic_default (self ):
@@ -110,13 +131,13 @@ class DataFrame(SciType):
110
131
def validate (self , obj , value ):
111
132
if value is None and not self .allow_none :
112
133
self .error (obj , value )
134
+ if value is None or value is Undefined :
135
+ return super (DataFrame , self ).validate (obj , value )
113
136
try :
114
137
value = pd .DataFrame (value )
115
- for validator in self .validators :
116
- value = validator (self , value )
117
- return value
118
138
except (ValueError , TypeError ) as e :
119
139
raise TraitError (e )
140
+ return super (DataFrame , self ).validate (obj , value )
120
141
121
142
def set (self , obj , value ):
122
143
new_value = self ._validate (obj , value )
@@ -132,7 +153,6 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
132
153
default_value = pd .DataFrame ()
133
154
elif default_value is not None :
134
155
default_value = pd .DataFrame (default_value )
135
- self .validators = []
136
156
super (DataFrame , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
137
157
138
158
def make_dynamic_default (self ):
@@ -151,13 +171,13 @@ class Series(SciType):
151
171
def validate (self , obj , value ):
152
172
if value is None and not self .allow_none :
153
173
self .error (obj , value )
174
+ if value is None or value is Undefined :
175
+ return super (Series , self ).validate (obj , value )
154
176
try :
155
177
value = pd .Series (value )
156
- for validator in self .validators :
157
- value = validator (self , value )
158
- return value
159
178
except (ValueError , TypeError ) as e :
160
179
raise TraitError (e )
180
+ return super (Series , self ).validate (obj , value )
161
181
162
182
def set (self , obj , value ):
163
183
new_value = self ._validate (obj , value )
@@ -173,7 +193,6 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
173
193
default_value = pd .Series ()
174
194
elif default_value is not None :
175
195
default_value = pd .Series (default_value )
176
- self .validators = []
177
196
super (Series , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
178
197
179
198
def make_dynamic_default (self ):
0 commit comments