Skip to content

Commit 62c9be0

Browse files
Heeding most pylint warnings for _ctors.py
1 parent 5229880 commit 62c9be0

File tree

1 file changed

+57
-69
lines changed

1 file changed

+57
-69
lines changed

dpctl/tensor/_ctors.py

Lines changed: 57 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import dpctl.utils
2626
from dpctl.tensor._device import normalize_queue_device
2727

28+
__doc__ = "Implementation of creation functions in :module:`dpctl.tensor`"
29+
2830
_empty_tuple = tuple()
2931
_host_set = frozenset([None])
3032

@@ -34,45 +36,42 @@ def _get_dtype(dtype, sycl_obj, ref_type=None):
3436
if ref_type in [None, float] or np.issubdtype(ref_type, np.floating):
3537
dtype = ti.default_device_fp_type(sycl_obj)
3638
return dpt.dtype(dtype)
37-
elif ref_type in [bool, np.bool_]:
39+
if ref_type in [bool, np.bool_]:
3840
dtype = ti.default_device_bool_type(sycl_obj)
3941
return dpt.dtype(dtype)
40-
elif ref_type is int or np.issubdtype(ref_type, np.integer):
42+
if ref_type is int or np.issubdtype(ref_type, np.integer):
4143
dtype = ti.default_device_int_type(sycl_obj)
4244
return dpt.dtype(dtype)
43-
elif ref_type is complex or np.issubdtype(ref_type, np.complexfloating):
45+
if ref_type is complex or np.issubdtype(ref_type, np.complexfloating):
4446
dtype = ti.default_device_complex_type(sycl_obj)
4547
return dpt.dtype(dtype)
46-
else:
47-
raise TypeError(f"Reference type {ref_type} not recognized.")
48-
else:
49-
return dpt.dtype(dtype)
48+
raise TypeError(f"Reference type {ref_type} not recognized.")
49+
return dpt.dtype(dtype)
5050

5151

5252
def _array_info_dispatch(obj):
5353
if isinstance(obj, dpt.usm_ndarray):
5454
return obj.shape, obj.dtype, frozenset([obj.sycl_queue])
55-
elif isinstance(obj, np.ndarray):
55+
if isinstance(obj, np.ndarray):
5656
return obj.shape, obj.dtype, _host_set
57-
elif isinstance(obj, range):
57+
if isinstance(obj, range):
5858
return (len(obj),), int, _host_set
59-
elif isinstance(obj, bool):
59+
if isinstance(obj, bool):
6060
return _empty_tuple, bool, _host_set
61-
elif isinstance(obj, float):
61+
if isinstance(obj, float):
6262
return _empty_tuple, float, _host_set
63-
elif isinstance(obj, int):
63+
if isinstance(obj, int):
6464
return _empty_tuple, int, _host_set
65-
elif isinstance(obj, complex):
65+
if isinstance(obj, complex):
6666
return _empty_tuple, complex, _host_set
67-
elif isinstance(obj, (list, tuple, range)):
67+
if isinstance(obj, (list, tuple, range)):
6868
return _array_info_sequence(obj)
69-
elif any(
69+
if any(
7070
isinstance(obj, s)
7171
for s in [np.integer, np.floating, np.complexfloating, np.bool_]
7272
):
7373
return _empty_tuple, obj.dtype, _host_set
74-
else:
75-
raise ValueError(type(obj))
74+
raise ValueError(type(obj))
7675

7776

7877
def _array_info_sequence(li):
@@ -91,9 +90,7 @@ def _array_info_sequence(li):
9190
dt = np.promote_types(dt, el_dt)
9291
device = device.union(el_dev)
9392
else:
94-
raise ValueError(
95-
"Inconsistent dimensions, {} and {}".format(dim, el_dim)
96-
)
93+
raise ValueError(f"Inconsistent dimensions, {dim} and {el_dim}")
9794
if dim is None:
9895
dim = tuple()
9996
dt = float
@@ -206,18 +203,18 @@ def _map_to_device_dtype(dt, q):
206203
if np.issubdtype(dt, np.floating):
207204
if dtc == "f":
208205
return dt
209-
else:
210-
if dtc == "d" and d.has_aspect_fp64:
211-
return dt
212-
if dtc == "h" and d.has_aspect_fp16:
213-
return dt
214-
return dpt.dtype("f4")
215-
elif np.issubdtype(dt, np.complexfloating):
206+
if dtc == "d" and d.has_aspect_fp64:
207+
return dt
208+
if dtc == "h" and d.has_aspect_fp16:
209+
return dt
210+
return dpt.dtype("f4")
211+
if np.issubdtype(dt, np.complexfloating):
216212
if dtc == "F":
217213
return dt
218214
if dtc == "D" and d.has_aspect_fp64:
219215
return dt
220216
return dpt.dtype("c8")
217+
raise RuntimeError(f"Unrecognized data type '{dt}' encountered.")
221218

222219

223220
def _asarray_from_numpy_ndarray(
@@ -349,8 +346,7 @@ def asarray(
349346
raise ValueError(
350347
"Unrecognized order keyword value, expecting 'K', 'A', 'F', or 'C'."
351348
)
352-
else:
353-
order = order[0].upper()
349+
order = order[0].upper()
354350
# 4. Check that usm_type is None, or a valid value
355351
dpctl.utils.validate_usm_type(usm_type, allow_none=True)
356352
# 5. Normalize device/sycl_queue [keep it None if was None]
@@ -369,7 +365,7 @@ def asarray(
369365
sycl_queue=sycl_queue,
370366
order=order,
371367
)
372-
elif hasattr(obj, "__sycl_usm_array_interface__"):
368+
if hasattr(obj, "__sycl_usm_array_interface__"):
373369
sua_iface = getattr(obj, "__sycl_usm_array_interface__")
374370
membuf = dpm.as_usm_memory(obj)
375371
ary = dpt.usm_ndarray(
@@ -386,7 +382,7 @@ def asarray(
386382
sycl_queue=sycl_queue,
387383
order=order,
388384
)
389-
elif isinstance(obj, np.ndarray):
385+
if isinstance(obj, np.ndarray):
390386
if copy is False:
391387
raise ValueError(
392388
"Converting numpy.ndarray to usm_ndarray requires a copy"
@@ -398,7 +394,7 @@ def asarray(
398394
sycl_queue=sycl_queue,
399395
order=order,
400396
)
401-
elif _is_object_with_buffer_protocol(obj):
397+
if _is_object_with_buffer_protocol(obj):
402398
if copy is False:
403399
raise ValueError(
404400
f"Converting {type(obj)} to usm_ndarray requires a copy"
@@ -410,12 +406,12 @@ def asarray(
410406
sycl_queue=sycl_queue,
411407
order=order,
412408
)
413-
elif isinstance(obj, (list, tuple, range)):
409+
if isinstance(obj, (list, tuple, range)):
414410
if copy is False:
415411
raise ValueError(
416412
"Converting Python sequence to usm_ndarray requires a copy"
417413
)
418-
_, dt, devs = _array_info_sequence(obj)
414+
_, _, devs = _array_info_sequence(obj)
419415
if devs == _host_set:
420416
return _asarray_from_numpy_ndarray(
421417
np.asarray(obj, dtype=dtype, order=order),
@@ -474,8 +470,7 @@ def empty(
474470
raise ValueError(
475471
"Unrecognized order keyword value, expecting 'F' or 'C'."
476472
)
477-
else:
478-
order = order[0].upper()
473+
order = order[0].upper()
479474
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
480475
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
481476
dtype = _get_dtype(dtype, sycl_queue)
@@ -497,14 +492,13 @@ def _coerce_and_infer_dt(*args, dt, sycl_queue, err_msg, allow_bool=False):
497492
dt = _get_dtype(dt, sycl_queue, ref_type=seq_dt)
498493
if np.issubdtype(dt, np.integer):
499494
return tuple(int(v) for v in args), dt
500-
elif np.issubdtype(dt, np.floating):
495+
if np.issubdtype(dt, np.floating):
501496
return tuple(float(v) for v in args), dt
502-
elif np.issubdtype(dt, np.complexfloating):
497+
if np.issubdtype(dt, np.complexfloating):
503498
return tuple(complex(v) for v in args), dt
504-
elif allow_bool and dt.char == "?":
499+
if allow_bool and dt.char == "?":
505500
return tuple(bool(v) for v in args), dt
506-
else:
507-
raise ValueError(f"Data type {dt} is not supported")
501+
raise ValueError(f"Data type {dt} is not supported")
508502

509503

510504
def _round_for_arange(tmp):
@@ -570,7 +564,7 @@ def arange(
570564
is_bool = False
571565
if dtype:
572566
is_bool = (dtype is bool) or (dpt.dtype(dtype) == dpt.bool)
573-
(start_, stop_, step_), dt = _coerce_and_infer_dt(
567+
_, dt = _coerce_and_infer_dt(
574568
start,
575569
stop,
576570
step,
@@ -581,9 +575,7 @@ def arange(
581575
)
582576
try:
583577
tmp = _get_arange_length(start, stop, step)
584-
sh = int(tmp)
585-
if sh < 0:
586-
sh = 0
578+
sh = max(int(tmp), 0)
587579
except TypeError:
588580
sh = 0
589581
if is_bool and sh > 2:
@@ -655,8 +647,7 @@ def zeros(
655647
raise ValueError(
656648
"Unrecognized order keyword value, expecting 'F' or 'C'."
657649
)
658-
else:
659-
order = order[0].upper()
650+
order = order[0].upper()
660651
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
661652
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
662653
dtype = _get_dtype(dtype, sycl_queue)
@@ -703,8 +694,7 @@ def ones(
703694
raise ValueError(
704695
"Unrecognized order keyword value, expecting 'F' or 'C'."
705696
)
706-
else:
707-
order = order[0].upper()
697+
order = order[0].upper()
708698
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
709699
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
710700
dtype = _get_dtype(dtype, sycl_queue)
@@ -715,7 +705,7 @@ def ones(
715705
order=order,
716706
buffer_ctor_kwargs={"queue": sycl_queue},
717707
)
718-
hev, ev = ti._full_usm_ndarray(1, res, sycl_queue)
708+
hev, _ = ti._full_usm_ndarray(1, res, sycl_queue)
719709
hev.wait()
720710
return res
721711

@@ -759,8 +749,7 @@ def full(
759749
raise ValueError(
760750
"Unrecognized order keyword value, expecting 'F' or 'C'."
761751
)
762-
else:
763-
order = order[0].upper()
752+
order = order[0].upper()
764753
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
765754
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
766755
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
@@ -771,7 +760,7 @@ def full(
771760
order=order,
772761
buffer_ctor_kwargs={"queue": sycl_queue},
773762
)
774-
hev, ev = ti._full_usm_ndarray(fill_value, res, sycl_queue)
763+
hev, _ = ti._full_usm_ndarray(fill_value, res, sycl_queue)
775764
hev.wait()
776765
return res
777766

@@ -811,8 +800,7 @@ def empty_like(
811800
raise ValueError(
812801
"Unrecognized order keyword value, expecting 'F' or 'C'."
813802
)
814-
else:
815-
order = order[0].upper()
803+
order = order[0].upper()
816804
if dtype is None:
817805
dtype = x.dtype
818806
if usm_type is None:
@@ -868,8 +856,7 @@ def zeros_like(
868856
raise ValueError(
869857
"Unrecognized order keyword value, expecting 'F' or 'C'."
870858
)
871-
else:
872-
order = order[0].upper()
859+
order = order[0].upper()
873860
if dtype is None:
874861
dtype = x.dtype
875862
if usm_type is None:
@@ -925,8 +912,7 @@ def ones_like(
925912
raise ValueError(
926913
"Unrecognized order keyword value, expecting 'F' or 'C'."
927914
)
928-
else:
929-
order = order[0].upper()
915+
order = order[0].upper()
930916
if dtype is None:
931917
dtype = x.dtype
932918
if usm_type is None:
@@ -989,8 +975,7 @@ def full_like(
989975
raise ValueError(
990976
"Unrecognized order keyword value, expecting 'F' or 'C'."
991977
)
992-
else:
993-
order = order[0].upper()
978+
order = order[0].upper()
994979
if dtype is None:
995980
dtype = x.dtype
996981
if usm_type is None:
@@ -1142,8 +1127,7 @@ def eye(
11421127
raise ValueError(
11431128
"Unrecognized order keyword value, expecting 'F' or 'C'."
11441129
)
1145-
else:
1146-
order = order[0].upper()
1130+
order = order[0].upper()
11471131
n_rows = operator.index(n_rows)
11481132
n_cols = n_rows if n_cols is None else operator.index(n_cols)
11491133
k = operator.index(k)
@@ -1178,12 +1162,14 @@ def tril(X, k=0):
11781162
11791163
Returns the lower triangular part of a matrix (or a stack of matrices) X.
11801164
"""
1181-
if type(X) is not dpt.usm_ndarray:
1182-
raise TypeError
1165+
if not isinstance(X, dpt.usm_ndarray):
1166+
raise TypeError(
1167+
"Expected argument of type dpctl.tensor.usm_ndarray, "
1168+
f"got {type(X)}."
1169+
)
11831170

11841171
k = operator.index(k)
11851172

1186-
# F_CONTIGUOUS = 2
11871173
order = "F" if (X.flags.f_contiguous) else "C"
11881174

11891175
shape = X.shape
@@ -1219,12 +1205,14 @@ def triu(X, k=0):
12191205
12201206
Returns the upper triangular part of a matrix (or a stack of matrices) X.
12211207
"""
1222-
if type(X) is not dpt.usm_ndarray:
1223-
raise TypeError
1208+
if not isinstance(X, dpt.usm_ndarray):
1209+
raise TypeError(
1210+
"Expected argument of type dpctl.tensor.usm_ndarray, "
1211+
f"got {type(X)}."
1212+
)
12241213

12251214
k = operator.index(k)
12261215

1227-
# F_CONTIGUOUS = 2
12281216
order = "F" if (X.flags.f_contiguous) else "C"
12291217

12301218
shape = X.shape

0 commit comments

Comments
 (0)