Skip to content

Commit 56a03e8

Browse files
authored
Remove data synchronize where possible (#1930)
* Remove data sync from arraycreation and histogram functions * Remove data sync from elementwise functions * Remove data sync from manipulation functions
1 parent 4f2e300 commit 56a03e8

File tree

7 files changed

+34
-106
lines changed

7 files changed

+34
-106
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ def __call__(
179179
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
180180
res_usm = super().__call__(x_usm, out=out_usm, order=order)
181181

182-
dpnp.synchronize_array_data(res_usm)
183182
if out is not None and isinstance(out, dpnp_array):
184183
return out
185184
return dpnp_array._create_from_usm_ndarray(res_usm)
@@ -352,7 +351,6 @@ def __call__(
352351
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
353352
res_usm = super().__call__(x1_usm, x2_usm, out=out_usm, order=order)
354353

355-
dpnp.synchronize_array_data(res_usm)
356354
if out is not None and isinstance(out, dpnp_array):
357355
return out
358356
return dpnp_array._create_from_usm_ndarray(res_usm)
@@ -540,7 +538,6 @@ def __call__(self, x, decimals=0, out=None, dtype=None):
540538
if dtype is not None:
541539
res_usm = dpt.astype(res_usm, dtype, copy=False)
542540

543-
dpnp.synchronize_array_data(res_usm)
544541
if out is not None and isinstance(out, dpnp_array):
545542
return out
546543
return dpnp_array._create_from_usm_ndarray(res_usm)

dpnp/dpnp_array.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,6 @@ def __getitem__(self, key):
257257

258258
res = self.__new__(dpnp_array)
259259
res._array_obj = item
260-
261-
if self._array_obj.usm_data is not res._array_obj.usm_data:
262-
dpnp.synchronize_array_data(self)
263260
return res
264261

265262
def __gt__(self, other):
@@ -456,7 +453,6 @@ def __setitem__(self, key, val):
456453
val = val.get_array()
457454

458455
self._array_obj.__setitem__(key, val)
459-
dpnp.synchronize_array_data(self)
460456

461457
# '__setstate__',
462458
# '__sizeof__',

dpnp/dpnp_container.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def arange(
8080
sycl_queue=sycl_queue_normalized,
8181
)
8282

83-
dpnp.synchronize_array_data(array_obj)
8483
return dpnp_array(array_obj.shape, buffer=array_obj)
8584

8685

@@ -133,7 +132,6 @@ def asarray(
133132
if array_obj is x1_obj and isinstance(x1, dpnp_array):
134133
return x1
135134

136-
dpnp.synchronize_array_data(array_obj)
137135
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
138136

139137

@@ -143,7 +141,6 @@ def copy(x1, /, *, order="K"):
143141
order = "K"
144142

145143
array_obj = dpt.copy(dpnp.get_usm_ndarray(x1), order=order)
146-
dpnp.synchronize_array_data(array_obj)
147144
return dpnp_array(array_obj.shape, buffer=array_obj, order="K")
148145

149146

@@ -205,7 +202,6 @@ def eye(
205202
usm_type=usm_type,
206203
sycl_queue=sycl_queue_normalized,
207204
)
208-
dpnp.synchronize_array_data(array_obj)
209205
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
210206

211207

@@ -240,7 +236,6 @@ def full(
240236
usm_type=usm_type,
241237
sycl_queue=sycl_queue_normalized,
242238
)
243-
dpnp.synchronize_array_data(array_obj)
244239
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
245240

246241

@@ -269,21 +264,18 @@ def ones(
269264
usm_type=usm_type,
270265
sycl_queue=sycl_queue_normalized,
271266
)
272-
dpnp.synchronize_array_data(array_obj)
273267
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
274268

275269

276270
def tril(x1, /, *, k=0):
277271
"""Creates `dpnp_array` as lower triangular part of an input array."""
278272
array_obj = dpt.tril(dpnp.get_usm_ndarray(x1), k=k)
279-
dpnp.synchronize_array_data(array_obj)
280273
return dpnp_array(array_obj.shape, buffer=array_obj, order="K")
281274

282275

283276
def triu(x1, /, *, k=0):
284277
"""Creates `dpnp_array` as upper triangular part of an input array."""
285278
array_obj = dpt.triu(dpnp.get_usm_ndarray(x1), k=k)
286-
dpnp.synchronize_array_data(array_obj)
287279
return dpnp_array(array_obj.shape, buffer=array_obj, order="K")
288280

289281

@@ -312,6 +304,4 @@ def zeros(
312304
usm_type=usm_type,
313305
sycl_queue=sycl_queue_normalized,
314306
)
315-
# TODO: uncomment once dpctl implements asynchronous call
316-
# dpnp.synchronize_array_data(array_obj)
317307
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)

dpnp/dpnp_iface.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,6 @@ def astype(x1, dtype, order="K", casting="unsafe", copy=True, device=None):
284284
x1_obj, dtype, order=order, casting=casting, copy=copy, device=device
285285
)
286286

287-
dpnp.synchronize_array_data(x1)
288287
if array_obj is x1_obj and isinstance(x1, dpnp_array):
289288
# return x1 if dpctl returns a zero copy of x1_obj
290289
return x1
@@ -797,6 +796,5 @@ def synchronize_array_data(a):
797796
798797
"""
799798

800-
if hasattr(dpu, "SequentialOrderManager"):
801-
check_supported_arrays_type(a)
802-
dpu.SequentialOrderManager[a.sycl_queue].wait()
799+
check_supported_arrays_type(a)
800+
dpu.SequentialOrderManager[a.sycl_queue].wait()

dpnp/dpnp_iface_arraycreation.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,7 +2188,7 @@ def geomspace(
21882188
21892189
"""
21902190

2191-
res = dpnp_geomspace(
2191+
return dpnp_geomspace(
21922192
start,
21932193
stop,
21942194
num,
@@ -2200,9 +2200,6 @@ def geomspace(
22002200
axis=axis,
22012201
)
22022202

2203-
dpnp.synchronize_array_data(res)
2204-
return res
2205-
22062203

22072204
def identity(
22082205
n,
@@ -2410,7 +2407,7 @@ def linspace(
24102407
24112408
"""
24122409

2413-
res = dpnp_linspace(
2410+
return dpnp_linspace(
24142411
start,
24152412
stop,
24162413
num,
@@ -2423,12 +2420,6 @@ def linspace(
24232420
axis=axis,
24242421
)
24252422

2426-
if isinstance(res, tuple): # (result, step) is returning
2427-
dpnp.synchronize_array_data(res[0])
2428-
else:
2429-
dpnp.synchronize_array_data(res)
2430-
return res
2431-
24322423

24332424
def loadtxt(
24342425
fname,
@@ -2643,7 +2634,7 @@ def logspace(
26432634
26442635
"""
26452636

2646-
res = dpnp_logspace(
2637+
return dpnp_logspace(
26472638
start,
26482639
stop,
26492640
num=num,
@@ -2656,9 +2647,6 @@ def logspace(
26562647
axis=axis,
26572648
)
26582649

2659-
dpnp.synchronize_array_data(res)
2660-
return res
2661-
26622650

26632651
# pylint: disable=redefined-outer-name
26642652
def meshgrid(*xi, copy=True, sparse=False, indexing="xy"):
@@ -2759,9 +2747,7 @@ def meshgrid(*xi, copy=True, sparse=False, indexing="xy"):
27592747
if copy:
27602748
output = [dpt.copy(x) for x in output]
27612749

2762-
dpnp.synchronize_array_data(output[0])
2763-
output = [dpnp_array._create_from_usm_ndarray(x) for x in output]
2764-
return output
2750+
return [dpnp_array._create_from_usm_ndarray(x) for x in output]
27652751

27662752

27672753
class MGridClass:

dpnp/dpnp_iface_histograms.py

Lines changed: 28 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,11 @@
4040
import operator
4141
import warnings
4242

43-
import dpctl.tensor as dpt
4443
import dpctl.utils as dpu
4544
import numpy
4645

4746
import dpnp
4847

49-
from .dpnp_algo.dpnp_arraycreation import (
50-
dpnp_linspace,
51-
)
52-
from .dpnp_array import dpnp_array
53-
5448
__all__ = [
5549
"digitize",
5650
"histogram",
@@ -63,10 +57,14 @@
6357

6458

6559
def _ravel_check_a_and_weights(a, weights):
66-
"""Check input `a` and `weights` arrays, and ravel both."""
60+
"""
61+
Check input `a` and `weights` arrays, and ravel both.
62+
The returned array have :class:`dpnp.ndarray` type always.
63+
64+
"""
6765

6866
# ensure that `a` array has supported type
69-
a = dpnp.get_usm_ndarray(a)
67+
dpnp.check_supported_arrays_type(a)
7068
usm_type = a.usm_type
7169

7270
# ensure that the array is a "subtractable" dtype
@@ -77,11 +75,11 @@ def _ravel_check_a_and_weights(a, weights):
7775
RuntimeWarning,
7876
stacklevel=3,
7977
)
80-
a = dpt.astype(a, numpy.uint8)
78+
a = dpnp.astype(a, numpy.uint8)
8179

8280
if weights is not None:
8381
# check that `weights` array has supported type
84-
weights = dpnp.get_usm_ndarray(weights)
82+
dpnp.check_supported_arrays_type(weights)
8583
usm_type = dpu.get_coerced_usm_type([usm_type, weights.usm_type])
8684

8785
# check that arrays have the same allocation queue
@@ -92,9 +90,9 @@ def _ravel_check_a_and_weights(a, weights):
9290

9391
if weights.shape != a.shape:
9492
raise ValueError("weights should have the same shape as a.")
95-
weights = dpt.reshape(weights, -1)
93+
weights = dpnp.ravel(weights)
9694

97-
a = dpt.reshape(a, -1)
95+
a = dpnp.ravel(a)
9896
return a, weights, usm_type
9997

10098

@@ -120,7 +118,7 @@ def _get_outer_edges(a, range):
120118
first_edge, last_edge = 0, 1
121119

122120
else:
123-
first_edge, last_edge = dpt.min(a), dpt.max(a)
121+
first_edge, last_edge = a.min(), a.max()
124122
if not (dpnp.isfinite(first_edge) and dpnp.isfinite(last_edge)):
125123
raise ValueError(
126124
f"autodetected range of [{first_edge}, {last_edge}] "
@@ -164,9 +162,11 @@ def _get_bin_edges(a, bins, range, usm_type):
164162
"a and bins must be allocated on the same SYCL queue"
165163
)
166164

167-
bin_edges = dpnp.as_usm_ndarray(
168-
bins, usm_type=usm_type, sycl_queue=sycl_queue
169-
)
165+
bin_edges = bins
166+
else:
167+
bin_edges = dpnp.asarray(
168+
bins, sycl_queue=sycl_queue, usm_type=usm_type
169+
)
170170

171171
if dpnp.any(bin_edges[:-1] > bin_edges[1:]):
172172
raise ValueError(
@@ -188,15 +188,15 @@ def _get_bin_edges(a, bins, range, usm_type):
188188
)
189189

190190
# bin edges must be computed
191-
bin_edges = dpnp_linspace(
191+
bin_edges = dpnp.linspace(
192192
first_edge,
193193
last_edge,
194194
n_equal_bins + 1,
195195
endpoint=True,
196196
dtype=bin_type,
197197
sycl_queue=sycl_queue,
198198
usm_type=usm_type,
199-
).get_array()
199+
)
200200
return bin_edges, (first_edge, last_edge, n_equal_bins)
201201
return bin_edges, None
202202

@@ -209,11 +209,8 @@ def _search_sorted_inclusive(a, v):
209209
210210
"""
211211

212-
return dpt.concat(
213-
(
214-
dpt.searchsorted(a, v[:-1], side="left"),
215-
dpt.searchsorted(a, v[-1:], side="right"),
216-
)
212+
return dpnp.concatenate(
213+
(a.searchsorted(v[:-1], "left"), a.searchsorted(v[-1:], "right"))
217214
)
218215

219216

@@ -305,14 +302,8 @@ def digitize(x, bins, right=False):
305302
# Use dpnp.searchsorted directly if bins are increasing
306303
return dpnp.searchsorted(bins, x, side=side)
307304

308-
usm_x = dpnp.get_usm_ndarray(x)
309-
usm_bins = dpnp.get_usm_ndarray(bins)
310-
311305
# Reverse bins and adjust indices if bins are decreasing
312-
usm_res = usm_bins.size - dpt.searchsorted(usm_bins[::-1], usm_x, side=side)
313-
314-
dpnp.synchronize_array_data(usm_res)
315-
return dpnp_array._create_from_usm_ndarray(usm_res)
306+
return bins.size - dpnp.searchsorted(bins[::-1], x, side=side)
316307

317308

318309
def histogram(a, bins=10, range=None, density=None, weights=None):
@@ -426,36 +417,26 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
426417
else:
427418
# Compute via cumulative histogram
428419
if weights is None:
429-
sa = dpt.sort(a)
420+
sa = dpnp.sort(a)
430421
cum_n = _search_sorted_inclusive(sa, bin_edges)
431422
else:
432-
zero = dpt.zeros(
423+
zero = dpnp.zeros(
433424
1, dtype=ntype, sycl_queue=a.sycl_queue, usm_type=usm_type
434425
)
435-
sorting_index = dpt.argsort(a)
426+
sorting_index = dpnp.argsort(a)
436427
sa = a[sorting_index]
437428
sw = weights[sorting_index]
438-
cw = dpt.concat((zero, dpt.cumulative_sum(sw, dtype=ntype)))
429+
cw = dpnp.concatenate((zero, sw.cumsum(dtype=ntype)))
439430
bin_index = _search_sorted_inclusive(sa, bin_edges)
440431
cum_n = cw[bin_index]
441432

442433
n = dpnp.diff(cum_n)
443434

444-
# convert bin_edges to dpnp.ndarray
445-
bin_edges = dpnp_array._create_from_usm_ndarray(bin_edges)
446-
447435
if density:
448436
# pylint: disable=possibly-used-before-assignment
449-
db = dpnp.diff(bin_edges)
450-
db = dpt.astype(db.get_array(), dpnp.default_float_type())
437+
db = dpnp.diff(bin_edges).astype(dpnp.default_float_type())
438+
return n / db / n.sum(), bin_edges
451439

452-
usm_n = n.get_array()
453-
hist = usm_n / db / dpt.sum(usm_n)
454-
455-
dpnp.synchronize_array_data(hist)
456-
return dpnp_array._create_from_usm_ndarray(hist), bin_edges
457-
458-
dpnp.synchronize_array_data(n)
459440
return n, bin_edges
460441

461442

@@ -541,6 +522,4 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None):
541522

542523
a, weights, usm_type = _ravel_check_a_and_weights(a, weights)
543524
bin_edges, _ = _get_bin_edges(a, bins, range, usm_type)
544-
545-
dpnp.synchronize_array_data(bin_edges)
546-
return dpnp_array._create_from_usm_ndarray(bin_edges)
525+
return bin_edges

0 commit comments

Comments
 (0)