@@ -184,12 +184,14 @@ class NumpyFloat(PythonFloat):
184
184
The argument passed to the function.
185
185
"""
186
186
__slots__ = ('_rank' ,'_shape' ,'_order' ,'_class_type' )
187
+ _static_type = NumpyFloat64Type ()
187
188
name = 'float'
189
+
188
190
def __init__ (self , arg ):
189
191
self ._shape = arg .shape
190
192
self ._rank = arg .rank
191
193
self ._order = arg .order
192
- self ._class_type = arg . class_type . switch_basic_type (self .static_type ())
194
+ self ._class_type = NumpyNDArrayType (self .static_type ()) if self . _rank else self . static_type ( )
193
195
super ().__init__ (arg )
194
196
195
197
@property
@@ -250,7 +252,7 @@ def __init__(self, arg):
250
252
self ._shape = arg .shape
251
253
self ._rank = arg .rank
252
254
self ._order = arg .order
253
- self ._class_type = arg . class_type . switch_basic_type (self .static_type ())
255
+ self ._class_type = NumpyNDArrayType (self .static_type ()) if self . _rank else self . static_type ( )
254
256
super ().__init__ (arg )
255
257
256
258
@property
@@ -277,12 +279,14 @@ class NumpyInt(PythonInt):
277
279
The argument passed to the function.
278
280
"""
279
281
__slots__ = ('_shape' ,'_rank' ,'_order' ,'_class_type' )
282
+ _static_type = numpy_precision_map [(PrimitiveIntegerType (), PythonInt ._static_type .precision )]
280
283
name = 'int'
284
+
281
285
def __init__ (self , arg = None , base = 10 ):
282
286
self ._shape = arg .shape
283
287
self ._rank = arg .rank
284
288
self ._order = arg .order
285
- self ._class_type = arg . class_type . switch_basic_type (self .static_type ())
289
+ self ._class_type = NumpyNDArrayType (self .static_type ()) if self . _rank else self . static_type ( )
286
290
super ().__init__ (arg )
287
291
288
292
@property
@@ -374,7 +378,10 @@ class NumpyReal(PythonReal):
374
378
name = 'real'
375
379
def __new__ (cls , arg ):
376
380
if isinstance (arg .dtype , PythonNativeBool ):
377
- return NumpyInt (arg )
381
+ if arg .rank :
382
+ return NumpyInt (arg )
383
+ else :
384
+ return PythonInt (arg )
378
385
else :
379
386
return super ().__new__ (cls , arg )
380
387
@@ -452,14 +459,16 @@ class NumpyComplex(PythonComplex):
452
459
_real_cast = NumpyReal
453
460
_imag_cast = NumpyImag
454
461
__slots__ = ('_rank' ,'_shape' ,'_order' ,'_class_type' )
462
+ _static_type = NumpyComplex128Type ()
455
463
name = 'complex'
464
+
456
465
def __init__ (self , arg0 , arg1 = None ):
457
466
if arg1 is not None :
458
467
raise NotImplementedError ("Use builtin complex function not deprecated np.complex" )
459
468
self ._shape = arg0 .shape
460
469
self ._rank = arg0 .rank
461
470
self ._order = arg0 .order
462
- self ._class_type = arg0 . class_type . switch_basic_type (self .static_type ())
471
+ self ._class_type = NumpyNDArrayType (self .static_type ()) if self . _rank else self . static_type ( )
463
472
super ().__init__ (arg0 )
464
473
465
474
@property
0 commit comments