@@ -149,8 +149,9 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
149
149
std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
150
150
std::size_t dst_idx = iter_gid * k + axis_gid;
151
151
152
- auto res_ind = index_data[src_idx];
153
- vals_tp[dst_idx] = arg_tp[res_ind];
152
+ const IndexTy res_ind = index_data[src_idx];
153
+ const argTy v = arg_tp[res_ind];
154
+ vals_tp[dst_idx] = v;
154
155
inds_tp[dst_idx] = res_ind % axis_nelems;
155
156
});
156
157
});
@@ -425,8 +426,9 @@ sycl::event topk_merge_impl(
425
426
const std::size_t src_idx = iter_gid * alloc_len + axis_gid;
426
427
const std::size_t dst_idx = gid;
427
428
428
- const auto res_ind = index_data[src_idx];
429
- vals_tp[dst_idx] = arg_tp[res_ind];
429
+ const IndexTy res_ind = index_data[src_idx];
430
+ const argTy v = arg_tp[res_ind];
431
+ vals_tp[dst_idx] = v;
430
432
inds_tp[dst_idx] = (res_ind % axis_nelems);
431
433
});
432
434
});
@@ -538,11 +540,14 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
538
540
const std::size_t dst_idx = gid;
539
541
540
542
const IndexTy res_ind = tmp_tp[src_idx];
541
- vals_tp[dst_idx] = arg_tp[res_ind];
543
+ const argTy v = arg_tp[res_ind];
544
+ vals_tp[dst_idx] = v;
542
545
inds_tp[dst_idx] = (res_ind % axis_nelems);
543
546
});
544
547
});
545
548
549
+ write_topk_ev.wait ();
550
+
546
551
sycl::event cleanup_ev = exec_q.submit ([&](sycl::handler &cgh) {
547
552
cgh.depends_on (write_topk_ev);
548
553
0 commit comments