Skip to content

Commit fb50c2e

Browse files
authored
Adapt linalg functions from the LAPACK extension to asynchronous dpctl execution (#1922)
* Update LinAlg functions that use BLAS routines
* Decouple the einsum utility functions into a separate file
* Update LinAlg functions that use LAPACK routines
* Remove batch_call from dpnp_svd()
1 parent f9cbc62 commit fb50c2e

File tree

2 files changed

+212
-194
lines changed

2 files changed

+212
-194
lines changed

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -116,15 +116,14 @@ def _copy_array(x, copy_flag=False, dtype=None, order="C"):
116116

117117
exec_q = x_copy.sycl_queue
118118
_manager = dpu.SequentialOrderManager[exec_q]
119-
dep_evs = _manager.submitted_events
120119

121-
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
120+
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
122121
src=dpnp.get_usm_ndarray(x),
123122
dst=x_copy.get_array(),
124123
sycl_queue=exec_q,
125-
depends=dep_evs,
124+
depends=_manager.submitted_events,
126125
)
127-
_manager.add_event_pair(ht_copy_ev, copy_ev)
126+
_manager.add_event_pair(ht_ev, copy_ev)
128127
return x_copy
129128
return x
130129

@@ -356,14 +355,14 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
356355
x2_usm = dpnp.get_usm_ndarray(x2[i : i + chunk, ...])
357356
res_usm = dpnp.get_usm_ndarray(res[i : i + chunk, ...])
358357

359-
ht_blas_ev, blas_ev, row_major = bi._gemm_batch(
358+
ht_ev, blas_ev, row_major = bi._gemm_batch(
360359
exec_q,
361360
x1_usm,
362361
x2_usm,
363362
res_usm,
364363
depends=_manager.submitted_events,
365364
)
366-
_manager.add_event_pair(ht_blas_ev, blas_ev)
365+
_manager.add_event_pair(ht_ev, blas_ev)
367366

368367
res_shape = res.shape
369368
_, res_is_c_contig, res_is_f_contig = _define_contig_flag(res)
@@ -388,14 +387,15 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
388387

389388
def _gemm_matmul(exec_q, x1, x2, res):
390389
_manager = dpu.SequentialOrderManager[exec_q]
391-
ht_gemm_ev, gemm_ev, row_major = bi._gemm(
390+
391+
ht_ev, gemm_ev, row_major = bi._gemm(
392392
exec_q,
393393
dpnp.get_usm_ndarray(x1),
394394
dpnp.get_usm_ndarray(x2),
395395
dpnp.get_usm_ndarray(res),
396396
depends=_manager.submitted_events,
397397
)
398-
_manager.add_event_pair(ht_gemm_ev, gemm_ev)
398+
_manager.add_event_pair(ht_ev, gemm_ev)
399399

400400
if row_major:
401401
if res.flags.f_contiguous is True:
@@ -635,14 +635,14 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
635635
else:
636636
dot_func = "_dot"
637637

638-
ht_dot_ev, dot_ev = getattr(bi, dot_func)(
638+
ht_ev, dot_ev = getattr(bi, dot_func)(
639639
exec_q,
640640
dpnp.get_usm_ndarray(a),
641641
dpnp.get_usm_ndarray(b),
642642
dpnp.get_usm_ndarray(result),
643643
depends=_manager.submitted_events,
644644
)
645-
_manager.add_event_pair(ht_dot_ev, dot_ev)
645+
_manager.add_event_pair(ht_ev, dot_ev)
646646
else:
647647
# oneapi::mkl::blas::dot is slow for integer data type,
648648
# so using dpctl.tensor.vecdot instead
@@ -866,15 +866,16 @@ def dpnp_matmul(
866866
x_usm = dpnp.get_usm_ndarray(x2)
867867

868868
_manager = dpu.SequentialOrderManager[exec_q]
869-
ht_gemv_ev, gemv_ev = bi._gemv(
869+
870+
ht_ev, gemv_ev = bi._gemv(
870871
exec_q,
871872
a_usm,
872873
x_usm,
873874
dpnp.get_usm_ndarray(result),
874875
transpose,
875876
depends=_manager.submitted_events,
876877
)
877-
_manager.add_event_pair(ht_gemv_ev, gemv_ev)
878+
_manager.add_event_pair(ht_ev, gemv_ev)
878879
elif call_flag == "gemm":
879880
result = _gemm_matmul(
880881
exec_q,

0 commit comments

Comments
 (0)