25
25
import dpctl .utils
26
26
from dpctl .tensor ._device import normalize_queue_device
27
27
28
+ __doc__ = "Implementation of creation functions in :module:`dpctl.tensor`"
29
+
28
30
_empty_tuple = tuple ()
29
31
_host_set = frozenset ([None ])
30
32
@@ -34,45 +36,42 @@ def _get_dtype(dtype, sycl_obj, ref_type=None):
34
36
if ref_type in [None , float ] or np .issubdtype (ref_type , np .floating ):
35
37
dtype = ti .default_device_fp_type (sycl_obj )
36
38
return dpt .dtype (dtype )
37
- elif ref_type in [bool , np .bool_ ]:
39
+ if ref_type in [bool , np .bool_ ]:
38
40
dtype = ti .default_device_bool_type (sycl_obj )
39
41
return dpt .dtype (dtype )
40
- elif ref_type is int or np .issubdtype (ref_type , np .integer ):
42
+ if ref_type is int or np .issubdtype (ref_type , np .integer ):
41
43
dtype = ti .default_device_int_type (sycl_obj )
42
44
return dpt .dtype (dtype )
43
- elif ref_type is complex or np .issubdtype (ref_type , np .complexfloating ):
45
+ if ref_type is complex or np .issubdtype (ref_type , np .complexfloating ):
44
46
dtype = ti .default_device_complex_type (sycl_obj )
45
47
return dpt .dtype (dtype )
46
- else :
47
- raise TypeError (f"Reference type { ref_type } not recognized." )
48
- else :
49
- return dpt .dtype (dtype )
48
+ raise TypeError (f"Reference type { ref_type } not recognized." )
49
+ return dpt .dtype (dtype )
50
50
51
51
52
52
def _array_info_dispatch (obj ):
53
53
if isinstance (obj , dpt .usm_ndarray ):
54
54
return obj .shape , obj .dtype , frozenset ([obj .sycl_queue ])
55
- elif isinstance (obj , np .ndarray ):
55
+ if isinstance (obj , np .ndarray ):
56
56
return obj .shape , obj .dtype , _host_set
57
- elif isinstance (obj , range ):
57
+ if isinstance (obj , range ):
58
58
return (len (obj ),), int , _host_set
59
- elif isinstance (obj , bool ):
59
+ if isinstance (obj , bool ):
60
60
return _empty_tuple , bool , _host_set
61
- elif isinstance (obj , float ):
61
+ if isinstance (obj , float ):
62
62
return _empty_tuple , float , _host_set
63
- elif isinstance (obj , int ):
63
+ if isinstance (obj , int ):
64
64
return _empty_tuple , int , _host_set
65
- elif isinstance (obj , complex ):
65
+ if isinstance (obj , complex ):
66
66
return _empty_tuple , complex , _host_set
67
- elif isinstance (obj , (list , tuple , range )):
67
+ if isinstance (obj , (list , tuple , range )):
68
68
return _array_info_sequence (obj )
69
- elif any (
69
+ if any (
70
70
isinstance (obj , s )
71
71
for s in [np .integer , np .floating , np .complexfloating , np .bool_ ]
72
72
):
73
73
return _empty_tuple , obj .dtype , _host_set
74
- else :
75
- raise ValueError (type (obj ))
74
+ raise ValueError (type (obj ))
76
75
77
76
78
77
def _array_info_sequence (li ):
@@ -91,9 +90,7 @@ def _array_info_sequence(li):
91
90
dt = np .promote_types (dt , el_dt )
92
91
device = device .union (el_dev )
93
92
else :
94
- raise ValueError (
95
- "Inconsistent dimensions, {} and {}" .format (dim , el_dim )
96
- )
93
+ raise ValueError (f"Inconsistent dimensions, { dim } and { el_dim } " )
97
94
if dim is None :
98
95
dim = tuple ()
99
96
dt = float
@@ -206,18 +203,18 @@ def _map_to_device_dtype(dt, q):
206
203
if np .issubdtype (dt , np .floating ):
207
204
if dtc == "f" :
208
205
return dt
209
- else :
210
- if dtc == "d" and d .has_aspect_fp64 :
211
- return dt
212
- if dtc == "h" and d .has_aspect_fp16 :
213
- return dt
214
- return dpt .dtype ("f4" )
215
- elif np .issubdtype (dt , np .complexfloating ):
206
+ if dtc == "d" and d .has_aspect_fp64 :
207
+ return dt
208
+ if dtc == "h" and d .has_aspect_fp16 :
209
+ return dt
210
+ return dpt .dtype ("f4" )
211
+ if np .issubdtype (dt , np .complexfloating ):
216
212
if dtc == "F" :
217
213
return dt
218
214
if dtc == "D" and d .has_aspect_fp64 :
219
215
return dt
220
216
return dpt .dtype ("c8" )
217
+ raise RuntimeError (f"Unrecognized data type '{ dt } ' encountered." )
221
218
222
219
223
220
def _asarray_from_numpy_ndarray (
@@ -349,8 +346,7 @@ def asarray(
349
346
raise ValueError (
350
347
"Unrecognized order keyword value, expecting 'K', 'A', 'F', or 'C'."
351
348
)
352
- else :
353
- order = order [0 ].upper ()
349
+ order = order [0 ].upper ()
354
350
# 4. Check that usm_type is None, or a valid value
355
351
dpctl .utils .validate_usm_type (usm_type , allow_none = True )
356
352
# 5. Normalize device/sycl_queue [keep it None if was None]
@@ -369,7 +365,7 @@ def asarray(
369
365
sycl_queue = sycl_queue ,
370
366
order = order ,
371
367
)
372
- elif hasattr (obj , "__sycl_usm_array_interface__" ):
368
+ if hasattr (obj , "__sycl_usm_array_interface__" ):
373
369
sua_iface = getattr (obj , "__sycl_usm_array_interface__" )
374
370
membuf = dpm .as_usm_memory (obj )
375
371
ary = dpt .usm_ndarray (
@@ -386,7 +382,7 @@ def asarray(
386
382
sycl_queue = sycl_queue ,
387
383
order = order ,
388
384
)
389
- elif isinstance (obj , np .ndarray ):
385
+ if isinstance (obj , np .ndarray ):
390
386
if copy is False :
391
387
raise ValueError (
392
388
"Converting numpy.ndarray to usm_ndarray requires a copy"
@@ -398,7 +394,7 @@ def asarray(
398
394
sycl_queue = sycl_queue ,
399
395
order = order ,
400
396
)
401
- elif _is_object_with_buffer_protocol (obj ):
397
+ if _is_object_with_buffer_protocol (obj ):
402
398
if copy is False :
403
399
raise ValueError (
404
400
f"Converting { type (obj )} to usm_ndarray requires a copy"
@@ -410,12 +406,12 @@ def asarray(
410
406
sycl_queue = sycl_queue ,
411
407
order = order ,
412
408
)
413
- elif isinstance (obj , (list , tuple , range )):
409
+ if isinstance (obj , (list , tuple , range )):
414
410
if copy is False :
415
411
raise ValueError (
416
412
"Converting Python sequence to usm_ndarray requires a copy"
417
413
)
418
- _ , dt , devs = _array_info_sequence (obj )
414
+ _ , _ , devs = _array_info_sequence (obj )
419
415
if devs == _host_set :
420
416
return _asarray_from_numpy_ndarray (
421
417
np .asarray (obj , dtype = dtype , order = order ),
@@ -474,8 +470,7 @@ def empty(
474
470
raise ValueError (
475
471
"Unrecognized order keyword value, expecting 'F' or 'C'."
476
472
)
477
- else :
478
- order = order [0 ].upper ()
473
+ order = order [0 ].upper ()
479
474
dpctl .utils .validate_usm_type (usm_type , allow_none = False )
480
475
sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
481
476
dtype = _get_dtype (dtype , sycl_queue )
@@ -497,14 +492,13 @@ def _coerce_and_infer_dt(*args, dt, sycl_queue, err_msg, allow_bool=False):
497
492
dt = _get_dtype (dt , sycl_queue , ref_type = seq_dt )
498
493
if np .issubdtype (dt , np .integer ):
499
494
return tuple (int (v ) for v in args ), dt
500
- elif np .issubdtype (dt , np .floating ):
495
+ if np .issubdtype (dt , np .floating ):
501
496
return tuple (float (v ) for v in args ), dt
502
- elif np .issubdtype (dt , np .complexfloating ):
497
+ if np .issubdtype (dt , np .complexfloating ):
503
498
return tuple (complex (v ) for v in args ), dt
504
- elif allow_bool and dt .char == "?" :
499
+ if allow_bool and dt .char == "?" :
505
500
return tuple (bool (v ) for v in args ), dt
506
- else :
507
- raise ValueError (f"Data type { dt } is not supported" )
501
+ raise ValueError (f"Data type { dt } is not supported" )
508
502
509
503
510
504
def _round_for_arange (tmp ):
@@ -570,7 +564,7 @@ def arange(
570
564
is_bool = False
571
565
if dtype :
572
566
is_bool = (dtype is bool ) or (dpt .dtype (dtype ) == dpt .bool )
573
- ( start_ , stop_ , step_ ) , dt = _coerce_and_infer_dt (
567
+ _ , dt = _coerce_and_infer_dt (
574
568
start ,
575
569
stop ,
576
570
step ,
@@ -581,9 +575,7 @@ def arange(
581
575
)
582
576
try :
583
577
tmp = _get_arange_length (start , stop , step )
584
- sh = int (tmp )
585
- if sh < 0 :
586
- sh = 0
578
+ sh = max (int (tmp ), 0 )
587
579
except TypeError :
588
580
sh = 0
589
581
if is_bool and sh > 2 :
@@ -655,8 +647,7 @@ def zeros(
655
647
raise ValueError (
656
648
"Unrecognized order keyword value, expecting 'F' or 'C'."
657
649
)
658
- else :
659
- order = order [0 ].upper ()
650
+ order = order [0 ].upper ()
660
651
dpctl .utils .validate_usm_type (usm_type , allow_none = False )
661
652
sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
662
653
dtype = _get_dtype (dtype , sycl_queue )
@@ -703,8 +694,7 @@ def ones(
703
694
raise ValueError (
704
695
"Unrecognized order keyword value, expecting 'F' or 'C'."
705
696
)
706
- else :
707
- order = order [0 ].upper ()
697
+ order = order [0 ].upper ()
708
698
dpctl .utils .validate_usm_type (usm_type , allow_none = False )
709
699
sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
710
700
dtype = _get_dtype (dtype , sycl_queue )
@@ -715,7 +705,7 @@ def ones(
715
705
order = order ,
716
706
buffer_ctor_kwargs = {"queue" : sycl_queue },
717
707
)
718
- hev , ev = ti ._full_usm_ndarray (1 , res , sycl_queue )
708
+ hev , _ = ti ._full_usm_ndarray (1 , res , sycl_queue )
719
709
hev .wait ()
720
710
return res
721
711
@@ -759,8 +749,7 @@ def full(
759
749
raise ValueError (
760
750
"Unrecognized order keyword value, expecting 'F' or 'C'."
761
751
)
762
- else :
763
- order = order [0 ].upper ()
752
+ order = order [0 ].upper ()
764
753
dpctl .utils .validate_usm_type (usm_type , allow_none = False )
765
754
sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
766
755
dtype = _get_dtype (dtype , sycl_queue , ref_type = type (fill_value ))
@@ -771,7 +760,7 @@ def full(
771
760
order = order ,
772
761
buffer_ctor_kwargs = {"queue" : sycl_queue },
773
762
)
774
- hev , ev = ti ._full_usm_ndarray (fill_value , res , sycl_queue )
763
+ hev , _ = ti ._full_usm_ndarray (fill_value , res , sycl_queue )
775
764
hev .wait ()
776
765
return res
777
766
@@ -811,8 +800,7 @@ def empty_like(
811
800
raise ValueError (
812
801
"Unrecognized order keyword value, expecting 'F' or 'C'."
813
802
)
814
- else :
815
- order = order [0 ].upper ()
803
+ order = order [0 ].upper ()
816
804
if dtype is None :
817
805
dtype = x .dtype
818
806
if usm_type is None :
@@ -868,8 +856,7 @@ def zeros_like(
868
856
raise ValueError (
869
857
"Unrecognized order keyword value, expecting 'F' or 'C'."
870
858
)
871
- else :
872
- order = order [0 ].upper ()
859
+ order = order [0 ].upper ()
873
860
if dtype is None :
874
861
dtype = x .dtype
875
862
if usm_type is None :
@@ -925,8 +912,7 @@ def ones_like(
925
912
raise ValueError (
926
913
"Unrecognized order keyword value, expecting 'F' or 'C'."
927
914
)
928
- else :
929
- order = order [0 ].upper ()
915
+ order = order [0 ].upper ()
930
916
if dtype is None :
931
917
dtype = x .dtype
932
918
if usm_type is None :
@@ -989,8 +975,7 @@ def full_like(
989
975
raise ValueError (
990
976
"Unrecognized order keyword value, expecting 'F' or 'C'."
991
977
)
992
- else :
993
- order = order [0 ].upper ()
978
+ order = order [0 ].upper ()
994
979
if dtype is None :
995
980
dtype = x .dtype
996
981
if usm_type is None :
@@ -1142,8 +1127,7 @@ def eye(
1142
1127
raise ValueError (
1143
1128
"Unrecognized order keyword value, expecting 'F' or 'C'."
1144
1129
)
1145
- else :
1146
- order = order [0 ].upper ()
1130
+ order = order [0 ].upper ()
1147
1131
n_rows = operator .index (n_rows )
1148
1132
n_cols = n_rows if n_cols is None else operator .index (n_cols )
1149
1133
k = operator .index (k )
@@ -1178,12 +1162,14 @@ def tril(X, k=0):
1178
1162
1179
1163
Returns the lower triangular part of a matrix (or a stack of matrices) X.
1180
1164
"""
1181
- if type (X ) is not dpt .usm_ndarray :
1182
- raise TypeError
1165
+ if not isinstance (X , dpt .usm_ndarray ):
1166
+ raise TypeError (
1167
+ "Expected argument of type dpctl.tensor.usm_ndarray, "
1168
+ f"got { type (X )} ."
1169
+ )
1183
1170
1184
1171
k = operator .index (k )
1185
1172
1186
- # F_CONTIGUOUS = 2
1187
1173
order = "F" if (X .flags .f_contiguous ) else "C"
1188
1174
1189
1175
shape = X .shape
@@ -1219,12 +1205,14 @@ def triu(X, k=0):
1219
1205
1220
1206
Returns the upper triangular part of a matrix (or a stack of matrices) X.
1221
1207
"""
1222
- if type (X ) is not dpt .usm_ndarray :
1223
- raise TypeError
1208
+ if not isinstance (X , dpt .usm_ndarray ):
1209
+ raise TypeError (
1210
+ "Expected argument of type dpctl.tensor.usm_ndarray, "
1211
+ f"got { type (X )} ."
1212
+ )
1224
1213
1225
1214
k = operator .index (k )
1226
1215
1227
- # F_CONTIGUOUS = 2
1228
1216
order = "F" if (X .flags .f_contiguous ) else "C"
1229
1217
1230
1218
shape = X .shape
0 commit comments