Skip to content

Commit bbb55f1

Browse files
Introduce async_smart_free
This function is intended to replace the use of host_task submissions to manage USM temporary deallocations. Signature: sycl::event async_smart_free(sycl::queue &, const std::vector<sycl::event> &, std::unique_ptr<T, USMDeleter>, ...);
1 parent 535e471 commit bbb55f1

File tree

3 files changed

+45
-59
lines changed

3 files changed

+45
-59
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,14 +1605,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
16051605
n_counts, count_ptr, proj_op,
16061606
is_ascending, depends);
16071607

1608-
sort_ev = exec_q.submit([=](sycl::handler &cgh) {
1609-
cgh.depends_on(sort_ev);
1610-
const sycl::context &ctx = exec_q.get_context();
1611-
1612-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1613-
cgh.host_task(
1614-
[ctx, count_ptr]() { sycl_free_noexcept(count_ptr, ctx); });
1615-
});
1608+
sort_ev = dpctl::tensor::alloc_utils::async_smart_free(
1609+
exec_q, {sort_ev}, count_owner);
16161610

16171611
return sort_ev;
16181612
}
@@ -1655,19 +1649,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
16551649
}
16561650
}
16571651

1658-
sort_ev = exec_q.submit([=](sycl::handler &cgh) {
1659-
cgh.depends_on(sort_ev);
1660-
1661-
const sycl::context &ctx = exec_q.get_context();
1662-
1663-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1664-
cgh.host_task([ctx, count_ptr, tmp_arr]() {
1665-
sycl_free_noexcept(tmp_arr, ctx);
1666-
sycl_free_noexcept(count_ptr, ctx);
1667-
});
1668-
});
1669-
count_owner.release();
1670-
tmp_arr_owner.release();
1652+
sort_ev = dpctl::tensor::alloc_utils::async_smart_free(
1653+
exec_q, {sort_ev}, tmp_arr_owner, count_owner);
16711654
}
16721655

16731656
return sort_ev;
@@ -1819,16 +1802,9 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
18191802
});
18201803
});
18211804

1822-
sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
1823-
cgh.depends_on(map_back_ev);
1824-
1825-
const sycl::context &ctx = exec_q.get_context();
1826-
1827-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1828-
cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); });
1829-
});
1805+
sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
1806+
exec_q, {map_back_ev}, workspace_owner);
18301807

1831-
workspace_owner.release();
18321808
return cleanup_ev;
18331809
}
18341810

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -154,16 +154,9 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
154154
});
155155

156156
sycl::event cleanup_host_task_event =
157-
exec_q.submit([&](sycl::handler &cgh) {
158-
cgh.depends_on(write_out_ev);
159-
const sycl::context &ctx = exec_q.get_context();
160-
161-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
162-
cgh.host_task(
163-
[ctx, index_data] { sycl_free_noexcept(index_data, ctx); });
164-
});
157+
dpctl::tensor::alloc_utils::async_smart_free(exec_q, {write_out_ev},
158+
index_data_owner);
165159

166-
index_data_owner.release();
167160
return cleanup_host_task_event;
168161
};
169162

@@ -429,16 +422,9 @@ sycl::event topk_merge_impl(
429422
});
430423

431424
sycl::event cleanup_host_task_event =
432-
exec_q.submit([&](sycl::handler &cgh) {
433-
cgh.depends_on(write_topk_ev);
434-
const sycl::context &ctx = exec_q.get_context();
435-
436-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
437-
cgh.host_task(
438-
[ctx, index_data] { sycl_free_noexcept(index_data, ctx); });
439-
});
425+
dpctl::tensor::alloc_utils::async_smart_free(
426+
exec_q, {write_topk_ev}, index_data_owner);
440427

441-
index_data_owner.release();
442428
return cleanup_host_task_event;
443429
}
444430
}
@@ -537,16 +523,9 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
537523
});
538524
});
539525

540-
sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
541-
cgh.depends_on(write_topk_ev);
542-
543-
const sycl::context &ctx = exec_q.get_context();
544-
545-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
546-
cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); });
547-
});
526+
sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
527+
exec_q, {write_topk_ev}, workspace_owner);
548528

549-
workspace_owner.release();
550529
return cleanup_ev;
551530
}
552531

dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <iostream>
3131
#include <memory>
3232
#include <stdexcept>
33+
#include <vector>
3334

3435
#include "sycl/sycl.hpp"
3536

@@ -75,7 +76,8 @@ void sycl_free_noexcept(T *ptr, const sycl::context &ctx) noexcept
7576
}
7677
}
7778

78-
template <typename T> void sycl_free_noexcept(T *ptr, sycl::queue &q) noexcept
79+
template <typename T>
80+
void sycl_free_noexcept(T *ptr, const sycl::queue &q) noexcept
7981
{
8082
sycl_free_noexcept(ptr, q.get_context());
8183
}
@@ -89,7 +91,7 @@ class USMDeleter
8991
USMDeleter(const sycl::queue &q) : ctx_(q.get_context()) {}
9092
USMDeleter(const sycl::context &ctx) : ctx_(ctx) {}
9193

92-
template <typename T> void operator()(T *ptr)
94+
template <typename T> void operator()(T *ptr) const
9395
{
9496
sycl_free_noexcept(ptr, ctx_);
9597
}
@@ -138,6 +140,35 @@ smart_malloc_jost(std::size_t count,
138140
return smart_malloc<T>(count, q, sycl::usm::alloc::host, propList);
139141
}
140142

143+
template <typename... Args>
144+
sycl::event async_smart_free(sycl::queue &exec_q,
145+
const std::vector<sycl::event> &depends,
146+
Args &&...args)
147+
{
148+
constexpr std::size_t n = sizeof...(Args);
149+
150+
std::vector<void *> ptrs;
151+
ptrs.reserve(n);
152+
(ptrs.push_back(reinterpret_cast<void *>(args.get())), ...);
153+
154+
std::vector<USMDeleter> dels;
155+
dels.reserve(n);
156+
(dels.push_back(args.get_deleter()), ...);
157+
158+
sycl::event ht_e = exec_q.submit([&](sycl::handler &cgh) {
159+
cgh.depends_on(depends);
160+
161+
cgh.host_task([ptrs, dels]() {
162+
for (size_t i = 0; i < ptrs.size(); ++i) {
163+
dels[i](ptrs[i]);
164+
}
165+
});
166+
});
167+
(args.release(), ...);
168+
169+
return ht_e;
170+
}
171+
141172
} // end of namespace alloc_utils
142173
} // end of namespace tensor
143174
} // end of namespace dpctl

0 commit comments

Comments
 (0)