Skip to content

Commit 0390cfe

Browse files
Apply utility queues_are_compatible
1 parent 3bf842d commit 0390cfe

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -524,14 +524,13 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
524524
}
525525
}
526526

527-
// check same contexts
527+
// check compatibility of execution queue and allocation queue
528528
sycl::queue src_q = src.get_queue();
529529
sycl::queue dst_q = dst.get_queue();
530530

531-
sycl::context exec_ctx = exec_q.get_context();
532-
if (src_q.get_context() != exec_ctx || dst_q.get_context() != exec_ctx) {
531+
if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
533532
throw py::value_error(
534-
"Execution queue context is not the same as allocation contexts");
533+
"Execution queue is not compatible with allocation queues");
535534
}
536535

537536
int src_typenum = src.get_typenum();
@@ -938,10 +937,9 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
938937
sycl::queue src_q = src.get_queue();
939938
sycl::queue dst_q = dst.get_queue();
940939

941-
sycl::context exec_ctx = exec_q.get_context();
942-
if (src_q.get_context() != exec_ctx || dst_q.get_context() != exec_ctx) {
940+
if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
943941
throw py::value_error(
944-
"Execution queue context is not the same as allocation contexts");
942+
"Execution queue is not compatible with allocation queues");
945943
}
946944

947945
if (src_nelems == 1) {
@@ -1255,10 +1253,9 @@ void copy_numpy_ndarray_into_usm_ndarray(
12551253

12561254
sycl::queue dst_q = dst.get_queue();
12571255

1258-
sycl::context exec_ctx = exec_q.get_context();
1259-
if (dst_q.get_context() != exec_ctx) {
1260-
throw py::value_error("Execution queue context is not the same as the "
1261-
"allocation context");
1256+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
1257+
throw py::value_error("Execution queue is not compatible with the "
1258+
"allocation queue");
12621259
}
12631260

12641261
// here we assume that NumPy's type numbers agree with ours for types

0 commit comments

Comments
 (0)