@@ -116,15 +116,14 @@ def _copy_array(x, copy_flag=False, dtype=None, order="C"):
116
116
117
117
exec_q = x_copy .sycl_queue
118
118
_manager = dpu .SequentialOrderManager [exec_q ]
119
- dep_evs = _manager .submitted_events
120
119
121
- ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
120
+ ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
122
121
src = dpnp .get_usm_ndarray (x ),
123
122
dst = x_copy .get_array (),
124
123
sycl_queue = exec_q ,
125
- depends = dep_evs ,
124
+ depends = _manager . submitted_events ,
126
125
)
127
- _manager .add_event_pair (ht_copy_ev , copy_ev )
126
+ _manager .add_event_pair (ht_ev , copy_ev )
128
127
return x_copy
129
128
return x
130
129
@@ -356,14 +355,14 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
356
355
x2_usm = dpnp .get_usm_ndarray (x2 [i : i + chunk , ...])
357
356
res_usm = dpnp .get_usm_ndarray (res [i : i + chunk , ...])
358
357
359
- ht_blas_ev , blas_ev , row_major = bi ._gemm_batch (
358
+ ht_ev , blas_ev , row_major = bi ._gemm_batch (
360
359
exec_q ,
361
360
x1_usm ,
362
361
x2_usm ,
363
362
res_usm ,
364
363
depends = _manager .submitted_events ,
365
364
)
366
- _manager .add_event_pair (ht_blas_ev , blas_ev )
365
+ _manager .add_event_pair (ht_ev , blas_ev )
367
366
368
367
res_shape = res .shape
369
368
_ , res_is_c_contig , res_is_f_contig = _define_contig_flag (res )
@@ -388,14 +387,15 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
388
387
389
388
def _gemm_matmul (exec_q , x1 , x2 , res ):
390
389
_manager = dpu .SequentialOrderManager [exec_q ]
391
- ht_gemm_ev , gemm_ev , row_major = bi ._gemm (
390
+
391
+ ht_ev , gemm_ev , row_major = bi ._gemm (
392
392
exec_q ,
393
393
dpnp .get_usm_ndarray (x1 ),
394
394
dpnp .get_usm_ndarray (x2 ),
395
395
dpnp .get_usm_ndarray (res ),
396
396
depends = _manager .submitted_events ,
397
397
)
398
- _manager .add_event_pair (ht_gemm_ev , gemm_ev )
398
+ _manager .add_event_pair (ht_ev , gemm_ev )
399
399
400
400
if row_major :
401
401
if res .flags .f_contiguous is True :
@@ -635,14 +635,14 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
635
635
else :
636
636
dot_func = "_dot"
637
637
638
- ht_dot_ev , dot_ev = getattr (bi , dot_func )(
638
+ ht_ev , dot_ev = getattr (bi , dot_func )(
639
639
exec_q ,
640
640
dpnp .get_usm_ndarray (a ),
641
641
dpnp .get_usm_ndarray (b ),
642
642
dpnp .get_usm_ndarray (result ),
643
643
depends = _manager .submitted_events ,
644
644
)
645
- _manager .add_event_pair (ht_dot_ev , dot_ev )
645
+ _manager .add_event_pair (ht_ev , dot_ev )
646
646
else :
647
647
# oneapi::mkl::blas::dot is slow for integer data type,
648
648
# so using dpctl.tensor.vecdot instead
@@ -866,15 +866,16 @@ def dpnp_matmul(
866
866
x_usm = dpnp .get_usm_ndarray (x2 )
867
867
868
868
_manager = dpu .SequentialOrderManager [exec_q ]
869
- ht_gemv_ev , gemv_ev = bi ._gemv (
869
+
870
+ ht_ev , gemv_ev = bi ._gemv (
870
871
exec_q ,
871
872
a_usm ,
872
873
x_usm ,
873
874
dpnp .get_usm_ndarray (result ),
874
875
transpose ,
875
876
depends = _manager .submitted_events ,
876
877
)
877
- _manager .add_event_pair (ht_gemv_ev , gemv_ev )
878
+ _manager .add_event_pair (ht_ev , gemv_ev )
878
879
elif call_flag == "gemm" :
879
880
result = _gemm_matmul (
880
881
exec_q ,
0 commit comments