1- from traitlets import TraitType , TraitError
1+ from traitlets import TraitType , TraitError , Undefined
22import numpy as np
33
44
@@ -7,11 +7,13 @@ class Array(TraitType):
77 """A numpy array trait type."""
88
99 info_text = 'a numpy array'
10+ dtype = None
1011
1112 def validate (self , obj , value ):
13+ if value is None and not self .allow_none :
14+ self .error (obj , value )
1215 try :
13- return np .asarray (value , dtype = self .get_metadata ('dtype' ),
14- order = self .get_metadata ('order' ))
16+ return np .asarray (value , dtype = self .dtype )
1517 except (ValueError , TypeError ) as e :
1618 raise TraitError (e )
1719
@@ -21,3 +23,12 @@ def set(self, obj, value):
2123 obj ._trait_values [self .name ] = new_value
2224 if not np .array_equal (old_value , new_value ):
2325 obj ._notify_trait (self .name , old_value , new_value )
26+
27+ def __init__ (self , default_value = Undefined , allow_none = False ,
28+ dtype = None , ** kwargs ):
29+ self .dtype = dtype
30+ if default_value is Undefined :
31+ default_value = np .array (0 , dtype = self .dtype )
32+ elif default_value is not None :
33+ default_value = np .asarray (default_value , dtype = self .dtype )
34+ super (Array , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
0 commit comments