Skip to content
This repository was archived by the owner on Sep 22, 2025. It is now read-only.

Commit 44a6e78

Browse files
author
Mikolaj Komar
committed
Fix device memory when using local in row distribution
1 parent 3503271 commit 44a6e78

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

include/dr/mp/algorithms/reduce.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ inline auto dpl_reduce(rng::forward_range auto &&r, auto &&binary_op) {
3535
sycl::known_identity_v<Fn, T>, binary_op);
3636
} else {
3737
dr::drlog.debug(" peel 1st value\n");
38-
auto base = *rng::begin(r);
3938
return std::reduce(dpl_policy(),
4039
dr::__detail::direct_iterator(rng::begin(r) + 1),
4140
dr::__detail::direct_iterator(rng::end(r)),
42-
sycl_get(base), binary_op);
41+
sycl_get_deref(rng::begin(r)), binary_op);
4342
}
4443
}
4544
#else

include/dr/mp/containers/matrix_formats/csr_row_distribution.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,12 @@ class csr_row_distribution {
295295
auto my_tuple = std::make_tuple(real_row_size, segment_size_ * rank, offset, local_rows);
296296
view_helper_const = alloc.allocate(1);
297297

298-
view_helper_const[0] = my_tuple;
298+
299+
if (use_sycl()) {
300+
sycl_queue().memcpy(view_helper_const, &my_tuple, sizeof(view_tuple)).wait();
301+
} else {
302+
view_helper_const[0] = my_tuple;
303+
}
299304

300305
local_view = std::make_shared<view_type>(get_elem_view(vals_size_, view_helper_const, cols_data_, vals_data_, rank));
301306
}

include/dr/mp/sycl_support.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@ sycl::queue &sycl_queue();
1717

1818
namespace dr::mp::__detail {
1919

20+
//sometimes we only want to dereference iterator inside SYCL
21+
template <typename T> auto sycl_get_deref(T v) {
22+
using deref_type = decltype(*v);
23+
deref_type temp;
24+
{
25+
sycl::buffer<deref_type> buff(&temp, 1);
26+
sycl_queue().submit([&](auto &&h) {
27+
sycl::accessor access(buff, h, sycl::write_only, sycl::no_init);
28+
h.single_task([=](auto i) { access[0] = *v;});
29+
}).wait();
30+
}
31+
return temp;
32+
}
33+
2034
template <typename T> T sycl_get(T &v) {
2135
T temp;
2236
sycl_queue().memcpy(&temp, &v, sizeof(v)).wait();

0 commit comments

Comments
 (0)