2121from traceback import format_exception_only
2222from collections import namedtuple , OrderedDict
2323from itertools import chain , count , starmap
24- from typing import List , Dict , Any , Mapping
24+ from typing import List , Dict , Any , Mapping , Optional
2525
2626import numpy as np
2727
@@ -1021,6 +1021,7 @@ def bind_variable(descriptor, env, data, use_values):
10211021
10221022 values = {}
10231023 cast = None
1024+ dtype = object if isinstance (descriptor , StringDescriptor ) else float
10241025
10251026 if isinstance (descriptor , DiscreteDescriptor ):
10261027 if not descriptor .values :
@@ -1038,7 +1039,7 @@ def bind_variable(descriptor, env, data, use_values):
10381039 cast = DateTimeCast ()
10391040
10401041 func = FeatureFunc (descriptor .expression , source_vars , values , cast ,
1041- use_values = use_values )
1042+ use_values = use_values , dtype = dtype )
10421043 return descriptor , func
10431044
10441045
@@ -1216,7 +1217,10 @@ class FeatureFunc:
12161217 A function for casting the expressions result to the appropriate
12171218 type (e.g. string representation of date/time variables to floats)
12181219 """
1219- def __init__ (self , expression , args , extra_env = None , cast = None , use_values = False ):
1220+ dtype : Optional ['DType' ] = None
1221+
1222+ def __init__ (self , expression , args , extra_env = None , cast = None , use_values = False ,
1223+ dtype = None ):
12201224 self .expression = expression
12211225 self .args = args
12221226 self .extra_env = dict (extra_env or {})
@@ -1225,6 +1229,7 @@ def __init__(self, expression, args, extra_env=None, cast=None, use_values=False
12251229 self .cast = cast
12261230 self .mask_exceptions = True
12271231 self .use_values = use_values
1232+ self .dtype = dtype
12281233
12291234 def __call__ (self , table , * _ ):
12301235 if isinstance (table , Table ):
@@ -1252,7 +1257,7 @@ def __call_table(self, table):
12521257 y = list (starmap (f , args ))
12531258 if self .cast is not None :
12541259 y = self .cast (y )
1255- return y
1260+ return np . asarray ( y , dtype = self . dtype )
12561261
12571262 def __call_instance (self , instance : Instance ):
12581263 table = Table .from_numpy (
@@ -1281,7 +1286,8 @@ def extract_column(self, table: Table, var: Variable):
12811286
12821287 def __reduce__ (self ):
12831288 return type (self ), (self .expression , self .args ,
1284- self .extra_env , self .cast , self .use_values )
1289+ self .extra_env , self .cast , self .use_values ,
1290+ self .dtype )
12851291
12861292 def __repr__ (self ):
12871293 return "{0.__name__}{1!r}" .format (* self .__reduce__ ())
0 commit comments