Skip to content

Commit d4f5aa4

Browse files
Use unique_ptr as temporary owner of USM allocation
The unique_ptr owns the USM allocation until the pointer is passed over to the host function, at which point the unique_ptr's ownership is released. Also reduced allocation sizes where too much was being allocated.
1 parent fa133b1 commit d4f5aa4

File tree

3 files changed

+50
-18
lines changed

3 files changed

+50
-18
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <array>
3131
#include <cstdint>
3232
#include <limits>
33+
#include <memory>
3334
#include <stdexcept>
3435
#include <type_traits>
3536
#include <utility>
@@ -1590,10 +1591,13 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
15901591
// memory for storing count and offset values
15911592
CountT *count_ptr =
15921593
sycl::malloc_device<CountT>(n_iters * n_counts, exec_q);
1594+
15931595
if (nullptr == count_ptr) {
15941596
throw std::runtime_error("Could not allocate USM-device memory");
15951597
}
15961598

1599+
auto count_owner =
1600+
dpctl::tensor::alloc_utils::make_owner(count_ptr, exec_q);
15971601
constexpr std::uint32_t zero_radix_iter{0};
15981602

15991603
if constexpr (std::is_same_v<KeyT, bool>) {
@@ -1620,11 +1624,12 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
16201624
ValueT *tmp_arr =
16211625
sycl::malloc_device<ValueT>(n_iters * n_to_sort, exec_q);
16221626
if (nullptr == tmp_arr) {
1623-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1624-
sycl_free_noexcept(count_ptr, exec_q);
16251627
throw std::runtime_error("Could not allocate USM-device memory");
16261628
}
16271629

1630+
auto tmp_arr_owner =
1631+
dpctl::tensor::alloc_utils::make_owner(tmp_arr, exec_q);
1632+
16281633
// iterations per each bucket
16291634
assert("Number of iterations must be even" && radix_iters % 2 == 0);
16301635
assert(radix_iters > 0);
@@ -1668,6 +1673,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
16681673
sycl_free_noexcept(count_ptr, ctx);
16691674
});
16701675
});
1676+
count_owner.release();
1677+
tmp_arr_owner.release();
16711678
}
16721679

16731680
return sort_ev;
@@ -1769,13 +1776,13 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
17691776
reinterpret_cast<IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
17701777

17711778
const std::size_t total_nelems = iter_nelems * sort_nelems;
1772-
const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
1773-
IndexTy *workspace = sycl::malloc_device<IndexTy>(
1774-
padded_total_nelems + total_nelems, exec_q);
1779+
IndexTy *workspace = sycl::malloc_device<IndexTy>(total_nelems, exec_q);
17751780

17761781
if (nullptr == workspace) {
17771782
throw std::runtime_error("Could not allocate workspace on device");
17781783
}
1784+
auto workspace_owner =
1785+
dpctl::tensor::alloc_utils::make_owner(workspace, exec_q);
17791786

17801787
using IdentityProjT = radix_sort_details::IdentityProj;
17811788
using IndexedProjT =
@@ -1829,6 +1836,7 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
18291836
cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); });
18301837
});
18311838

1839+
workspace_owner.release();
18321840
return cleanup_ev;
18331841
}
18341842

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
#include <cstdint>
3030
#include <iterator>
3131
#include <limits>
32+
#include <memory>
3233
#include <stdexcept>
33-
#include <sycl/sycl.hpp>
3434
#include <vector>
3535

36+
#include <sycl/sycl.hpp>
37+
3638
#include "kernels/dpctl_tensor_types.hpp"
3739
#include "kernels/sorting/merge_sort.hpp"
3840
#include "kernels/sorting/radix_sort.hpp"
@@ -95,6 +97,8 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
9597
if (index_data == nullptr) {
9698
throw std::runtime_error("Unable to allocate device_memory");
9799
}
100+
auto index_data_owner =
101+
dpctl::tensor::alloc_utils::make_owner(index_data, exec_q);
98102

99103
using IotaKernelName = topk_populate_index_data_krn<argTy, IndexTy, CompT>;
100104

@@ -161,6 +165,7 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
161165
[ctx, index_data] { sycl_free_noexcept(index_data, ctx); });
162166
});
163167

168+
index_data_owner.release();
164169
return cleanup_host_task_event;
165170
};
166171

@@ -287,6 +292,8 @@ sycl::event topk_merge_impl(
287292
if (index_data == nullptr) {
288293
throw std::runtime_error("Unable to allocate device_memory");
289294
}
295+
auto index_data_owner =
296+
dpctl::tensor::alloc_utils::make_owner(index_data, exec_q);
290297

291298
// no need to populate index data: SLM will be populated with default
292299
// values
@@ -434,6 +441,7 @@ sycl::event topk_merge_impl(
434441
[ctx, index_data] { sycl_free_noexcept(index_data, ctx); });
435442
});
436443

444+
index_data_owner.release();
437445
return cleanup_host_task_event;
438446
}
439447
}
@@ -479,15 +487,10 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
479487
throw std::runtime_error(
480488
"Not enough device memory for radix sort topk");
481489
}
490+
auto workspace_owner =
491+
dpctl::tensor::alloc_utils::make_owner(workspace, exec_q);
482492

483-
IndexTy *tmp_tp = sycl::malloc_device<IndexTy>(total_nelems, exec_q);
484-
485-
if (nullptr == tmp_tp) {
486-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
487-
sycl_free_noexcept(workspace, exec_q);
488-
throw std::runtime_error(
489-
"Not enough device memory for radix sort topk");
490-
}
493+
IndexTy *tmp_tp = workspace + padded_total_nelems;
491494

492495
using IdentityProjT = radix_sort_details::IdentityProj;
493496
using IndexedProjT =
@@ -546,12 +549,10 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
546549
const sycl::context &ctx = exec_q.get_context();
547550

548551
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
549-
cgh.host_task([ctx, workspace, tmp_tp] {
550-
sycl_free_noexcept(workspace, ctx);
551-
sycl_free_noexcept(tmp_tp, ctx);
552-
});
552+
cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); });
553553
});
554554

555+
workspace_owner.release();
555556
return cleanup_ev;
556557
}
557558

dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include <exception>
3030
#include <iostream>
31+
#include <memory>
3132

3233
#include "sycl/sycl.hpp"
3334

@@ -78,6 +79,28 @@ template <typename T> void sycl_free_noexcept(T *ptr, sycl::queue &q) noexcept
7879
sycl_free_noexcept(ptr, q.get_context());
7980
}
8081

82+
class USMDeleter
83+
{
84+
private:
85+
sycl::context ctx_;
86+
87+
public:
88+
USMDeleter(const sycl::queue &q) : ctx_(q.get_context()) {}
89+
USMDeleter(const sycl::context &ctx) : ctx_(ctx) {}
90+
91+
template <typename T> void operator()(T *ptr)
92+
{
93+
sycl_free_noexcept(ptr, ctx_);
94+
}
95+
};
96+
97+
template <typename T>
98+
std::unique_ptr<T, USMDeleter> make_owner(T *ptr, const sycl::queue &q)
99+
{
100+
auto usm_deleter = USMDeleter(q);
101+
return std::unique_ptr<T, USMDeleter>(ptr, usm_deleter);
102+
}
103+
81104
} // end of namespace alloc_utils
82105
} // end of namespace tensor
83106
} // end of namespace dpctl

0 commit comments

Comments
 (0)