@@ -454,6 +454,7 @@ def empty(
454
454
455
455
456
456
def _coerce_and_infer_dt (* args , dt ):
457
+ "Deduce arange type from sequence spec"
457
458
nd , seq_dt , d = _array_info_sequence (args )
458
459
if d != _host_set or nd != (len (args ),):
459
460
raise ValueError ("start, stop and step must be Python scalars" )
@@ -470,6 +471,24 @@ def _coerce_and_infer_dt(*args, dt):
470
471
raise ValueError (f"Data type { dt } is not supported" )
471
472
472
473
474
+ def _get_arange_length (start , stop , step ):
475
+ "Compute length of arange sequence"
476
+ span = stop - start
477
+ if type (step ) in [int , float ] and type (span ) in [int , float ]:
478
+ offset = - 1 if step > 0 else 1
479
+ tmp = 1 + (span + offset ) / step
480
+ return tmp
481
+ tmp = span / step
482
+ if type (tmp ) is complex and tmp .imag == 0 :
483
+ tmp = tmp .real
484
+ else :
485
+ return tmp
486
+ k = int (tmp )
487
+ if k > 0 and float (k ) < tmp :
488
+ tmp = tmp + 1
489
+ return tmp
490
+
491
+
473
492
def arange (
474
493
start ,
475
494
/ ,
@@ -511,9 +530,7 @@ def arange(
511
530
step ,
512
531
), dt = _coerce_and_infer_dt (start , stop , step , dt = dtype )
513
532
try :
514
- tmp = 1 + (stop - start - 1 ) / step
515
- if type (tmp ) is complex and tmp .imag == 0.0 :
516
- tmp = tmp .real
533
+ tmp = _get_arange_length (start , stop , step )
517
534
sh = int (tmp )
518
535
if sh < 0 :
519
536
sh = 0
@@ -529,6 +546,7 @@ def arange(
529
546
buffer_ctor_kwargs = {"queue" : sycl_queue },
530
547
)
531
548
_step = (start + step ) - start
549
+ _step = dt .type (_step )
532
550
hev , _ = ti ._linspace_step (start , _step , res , sycl_queue )
533
551
hev .wait ()
534
552
return res
0 commit comments