Skip to content
This repository was archived by the owner on Sep 22, 2025. It is now read-only.

Commit 1f84ba7

Browse files
author
Mikolaj Komar
committed
Add problem to review
1 parent 4acbad6 commit 1f84ba7

File tree

4 files changed

+109
-26
lines changed

4 files changed

+109
-26
lines changed

examples/mp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_mp_example(hello_world)
3434
add_mp_example_no_test(sparse_matrix)
3535
add_mp_example_no_test(sparse_benchmark)
3636
add_mp_example_no_test(sparse_matrix_matrix_mul)
37+
add_mp_example_no_test(local_issue)
3738

3839
if(OpenMP_FOUND)
3940
add_executable(vector-add-ref vector-add-ref.cpp)

examples/mp/local_issue.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// SPDX-FileCopyrightText: Intel Corporation
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include <dr/mp.hpp>
6+
#include <fmt/core.h>
7+
8+
namespace mp = dr::mp;
9+
10+
int main(int argc, char **argv) {
11+
12+
#ifdef SYCL_LANGUAGE_VERSION
13+
mp::init(sycl::default_selector_v);
14+
#else
15+
mp::init();
16+
#endif
17+
18+
dr::views::csr_matrix_view<double, long> local_data;
19+
auto root = 0;
20+
if (root == dr::mp::default_comm().rank()) {
21+
local_data = dr::generate_band_csr<double, long>(100, 2, 2);
22+
}
23+
{
24+
mp::distributed_sparse_matrix<
25+
double, long, dr::mp::MpiBackend,
26+
dr::mp::csr_row_distribution<double, long, dr::mp::MpiBackend>>
27+
m_row(local_data, root);
28+
auto b = m_row.segments()[0].begin().local();
29+
auto [ind, val] = *b;
30+
auto [n, ma] = ind;
31+
fmt::print("some res 2 {} {} {}\n", val, n, ma);
32+
33+
}
34+
35+
if (root == dr::mp::default_comm().rank()) {
36+
dr::__detail::destroy_csr_matrix_view(local_data, std::allocator<double>{});
37+
}
38+
mp::finalize();
39+
40+
return 0;
41+
}

include/dr/mp/containers/matrix_formats/csr_row_segment.hpp

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,49 @@
44

55
#pragma once
66

7+
8+
int some_id_base =0;
79
namespace dr::mp {
10+
namespace __detail {
11+
template <typename T, typename V>
12+
class transform_fn_1 {
13+
public:
14+
using value_type = V;
15+
using index_type = T;
16+
transform_fn_1(std::size_t offset, std::size_t row_size, T* row_ptr):
17+
offset_(offset), row_size_(row_size), row_ptr_(row_ptr) {
18+
assert(offset_ == 0);
19+
myid = some_id_base++;
20+
fmt::print("created {}\n", myid);
21+
}
822

23+
~transform_fn_1() {
24+
destroyed = true;
25+
fmt::print("destroyed {}\n", myid);
26+
}
27+
template <typename P>
28+
auto operator()(P entry) const {
29+
fmt::print("called {}\n", myid);
30+
assert(offset_ == 0);
31+
assert(!destroyed);
32+
auto [index, pair] = entry;
33+
auto [val, column] = pair;
34+
auto row = rng::distance(
35+
row_ptr_,
36+
std::upper_bound(row_ptr_, row_ptr_ + row_size_, offset_) -
37+
1);
38+
dr::index<index_type> index_obj(row, column);
39+
value_type entry_obj(index_obj, val);
40+
return entry_obj;
41+
}
42+
private:
43+
int myid = 0;
44+
bool destroyed = false;
45+
std::size_t offset_;
46+
std::size_t row_size_;
47+
T* row_ptr_;
48+
};
49+
}
950
template <typename DSM> class csr_row_segment_iterator;
1051

1152
template <typename DSM> class csr_row_segment_reference {
@@ -55,6 +96,10 @@ template <typename DSM> class csr_row_segment_iterator {
5596
dsm_ = dsm;
5697
segment_index_ = segment_index;
5798
index_ = index;
99+
if (dsm_->vals_backend_.getrank() == segment_index_) {
100+
elem_view_ = get_elem_view(dsm_, segment_index);
101+
base_iter = elem_view_.begin();
102+
}
58103
}
59104

60105
auto operator<=>(const csr_row_segment_iterator &other) const noexcept {
@@ -227,40 +272,36 @@ template <typename DSM> class csr_row_segment_iterator {
227272

228273
auto local() const {
229274
const auto my_process_segment_index = dsm_->vals_backend_.getrank();
230-
231275
assert(my_process_segment_index == segment_index_);
232-
std::size_t offset = dsm_->segment_size_ * segment_index_;
233-
assert(offset == 0);
234-
// auto row_size = dsm_->segment_size_;
235-
auto vals_size = dsm_->vals_size_;
236-
auto local_vals = dsm_->vals_data_;
276+
auto [a, b] = *base_iter;
277+
auto [c, d] = a;
278+
fmt::print("aqwsedrftgyhuji {} {} {}\n", b, c, d);
279+
return base_iter;
280+
}
281+
282+
private:
283+
284+
static auto get_elem_view(DSM *dsm, std::size_t segment_index) {
285+
std::size_t offset = dsm->segment_size_ * segment_index;
286+
auto row_size = dsm->segment_size_;
287+
auto vals_size = dsm->vals_size_;
288+
auto local_vals = dsm->vals_data_;
237289
auto local_vals_range = rng::subrange(local_vals, local_vals + vals_size);
238-
auto local_cols = dsm_->cols_data_;
290+
auto local_cols = dsm->cols_data_;
239291
auto local_cols_range = rng::subrange(local_cols, local_cols + vals_size);
240-
// auto local_rows = dsm_->rows_data_->segments()[segment_index_].begin().local();
292+
auto local_rows = dsm->rows_data_->segments()[segment_index].begin().local();
241293
auto zipped_results = rng::views::zip(local_vals_range, local_cols_range);
242294
auto enumerated_zipped = rng::views::enumerate(zipped_results);
243-
auto transformer = [=](auto entry) {
244-
assert(offset == 0);
245-
auto [index, pair] = entry;
246-
auto [val, column] = pair;
247-
auto row = 0; //TODO fix calculating row - it results in segfault
248-
// problem originates from the fact that variables cannot be caputed properly by value
249-
// auto row = rng::distance(
250-
// local_rows,
251-
// std::upper_bound(local_rows, local_rows + row_size, offset) -
252-
// 1);
253-
dr::index<index_type> index_obj(row, column);
254-
value_type entry_obj(index_obj, val);
255-
return entry_obj;
256-
};
257-
auto transformed_res = rng::views::transform(enumerated_zipped, transformer);
258-
return transformed_res.begin();
295+
auto transformer = __detail::transform_fn_1<index_type, value_type>(offset, row_size, local_rows);
296+
return rng::views::transform(enumerated_zipped, transformer);
259297
}
260298

261-
private:
262299
// all fields need to be initialized by default ctor so every default
263300
// constructed iter is equal to any other default constructed iter
301+
using view_type = decltype(get_elem_view(std::declval<DSM*>(), 0));
302+
using iter_type = rng::iterator_t<view_type>;
303+
view_type elem_view_;
304+
iter_type base_iter;
264305
DSM *dsm_ = nullptr;
265306
std::size_t segment_index_ = 0;
266307
std::size_t index_ = 0;

include/dr/views/transform.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class transform_iterator {
109109
requires(dr::ranges::__detail::has_local<Iter>)
110110
{
111111
auto iter = dr::ranges::__detail::local(iter_);
112-
return transform_iterator<decltype(iter), F>(iter, fn_);
112+
return transform_iterator<decltype(iter), F>(std::move(iter), fn_);
113113
}
114114

115115
private:

0 commit comments

Comments
 (0)