Skip to content
This repository was archived by the owner on Sep 22, 2025. It is now read-only.

Commit ba20ee3

Browse files
author
Mikolaj Komar
committed
Moved local view to distribution
1 parent 1f84ba7 commit ba20ee3

File tree

3 files changed

+77
-70
lines changed

3 files changed

+77
-70
lines changed

examples/mp/local_issue.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ int main(int argc, char **argv) {
1818
dr::views::csr_matrix_view<double, long> local_data;
1919
auto root = 0;
2020
if (root == dr::mp::default_comm().rank()) {
21-
local_data = dr::generate_band_csr<double, long>(100, 2, 2);
21+
local_data = dr::generate_band_csr<double, long>(10, 0, 1);
2222
}
2323
{
2424
mp::distributed_sparse_matrix<
@@ -30,8 +30,13 @@ int main(int argc, char **argv) {
3030
auto [n, ma] = ind;
3131
fmt::print("some res 2 {} {} {}\n", val, n, ma);
3232

33+
auto mapper = [] (auto elem) { auto [a, b] = elem; auto [c, d] = a; return d;};
34+
auto summer = [](auto x, auto y) { return x + y;};
35+
auto z2 = dr::transform_view(m_row, mapper);
36+
auto red2 = dr::mp::reduce(z2, 0, summer);
37+
fmt::print("reduced row {} {}\n", red2, m_row.size());
3338
}
34-
39+
3540
if (root == dr::mp::default_comm().rank()) {
3641
dr::__detail::destroy_csr_matrix_view(local_data, std::allocator<double>{});
3742
}

include/dr/mp/containers/matrix_formats/csr_row_distribution.hpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,42 @@
88
#include <fmt/core.h>
99

1010
namespace dr::mp {
11+
namespace __detail {
12+
template <typename T, typename V>
13+
class transform_fn_1 {
14+
public:
15+
using value_type = V;
16+
using index_type = T;
17+
transform_fn_1(std::size_t offset, std::size_t row_size, T* row_ptr):
18+
offset_(offset), row_size_(row_size), row_ptr_(row_ptr) {
19+
assert(offset_ == 0);
20+
}
21+
22+
~transform_fn_1() {
23+
destroyed = true;
24+
}
25+
template <typename P>
26+
auto operator()(P entry) const {
27+
assert(offset_ == 0);
28+
assert(!destroyed);
29+
auto [index, pair] = entry;
30+
auto [val, column] = pair;
31+
auto row = 0;
32+
// auto row = rng::distance(
33+
// row_ptr_,
34+
// std::upper_bound(row_ptr_, row_ptr_ + row_size_, offset_ + index) -
35+
// 1);
36+
dr::index<index_type> index_obj(row, column);
37+
value_type entry_obj(index_obj, val);
38+
return entry_obj;
39+
}
40+
private:
41+
bool destroyed = false;
42+
std::size_t offset_;
43+
std::size_t row_size_;
44+
T* row_ptr_;
45+
};
46+
}
1147

1248
template <typename T, typename I, class BackendT = MpiBackend>
1349
class csr_row_distribution {
@@ -287,8 +323,41 @@ class csr_row_distribution {
287323
std::max(val_sizes_[i], static_cast<std::size_t>(1)));
288324
}
289325
fence();
326+
local_view = get_elem_view(vals_size_, cols_data_, vals_data_, rows_data_, rank);
290327
}
291328

329+
static auto get_elem_view(std::size_t vals_size,
330+
index_type *local_cols,
331+
elem_type *local_vals,
332+
std::shared_ptr<distributed_vector<I>> rows_data,
333+
std::size_t rank) {
334+
auto row_size = rows_data->segment_size();
335+
std::size_t offset = row_size * rank;
336+
auto local_vals_range = rng::subrange(local_vals, local_vals + vals_size);
337+
auto local_cols_range = rng::subrange(local_cols, local_cols + vals_size);
338+
// auto local_rows = rows_data->segments()[rank].begin().local();
339+
auto zipped_results = rng::views::zip(local_vals_range, local_cols_range);
340+
auto enumerated_zipped = rng::views::enumerate(zipped_results);
341+
auto transformer = [=](auto entry){
342+
assert(offset == 0);
343+
auto [index, pair] = entry;
344+
auto [val, column] = pair;
345+
auto row = 0;
346+
// auto row = rng::distance(
347+
// local_rows,
348+
// std::upper_bound(local_rows, local_rows + row_size, offset_ + index) -
349+
// 1);
350+
dr::index<index_type> index_obj(row, column);
351+
value_type entry_obj(index_obj, val);
352+
return entry_obj;
353+
};
354+
//__detail::transform_fn_1<index_type, value_type>(offset, row_size, local_rows);
355+
return rng::views::transform(enumerated_zipped, transformer);
356+
}
357+
358+
using view_type = decltype(get_elem_view(0, nullptr, nullptr, std::shared_ptr<distributed_vector<I>>(nullptr),0));
359+
360+
view_type local_view;
292361
std::size_t segment_size_ = 0;
293362
std::size_t vals_size_ = 0;
294363
std::vector<std::size_t> val_offsets_;

include/dr/mp/containers/matrix_formats/csr_row_segment.hpp

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -5,48 +5,7 @@
55
#pragma once
66

77

8-
int some_id_base =0;
98
namespace dr::mp {
10-
namespace __detail {
11-
template <typename T, typename V>
12-
class transform_fn_1 {
13-
public:
14-
using value_type = V;
15-
using index_type = T;
16-
transform_fn_1(std::size_t offset, std::size_t row_size, T* row_ptr):
17-
offset_(offset), row_size_(row_size), row_ptr_(row_ptr) {
18-
assert(offset_ == 0);
19-
myid = some_id_base++;
20-
fmt::print("created {}\n", myid);
21-
}
22-
23-
~transform_fn_1() {
24-
destroyed = true;
25-
fmt::print("destroyed {}\n", myid);
26-
}
27-
template <typename P>
28-
auto operator()(P entry) const {
29-
fmt::print("called {}\n", myid);
30-
assert(offset_ == 0);
31-
assert(!destroyed);
32-
auto [index, pair] = entry;
33-
auto [val, column] = pair;
34-
auto row = rng::distance(
35-
row_ptr_,
36-
std::upper_bound(row_ptr_, row_ptr_ + row_size_, offset_) -
37-
1);
38-
dr::index<index_type> index_obj(row, column);
39-
value_type entry_obj(index_obj, val);
40-
return entry_obj;
41-
}
42-
private:
43-
int myid = 0;
44-
bool destroyed = false;
45-
std::size_t offset_;
46-
std::size_t row_size_;
47-
T* row_ptr_;
48-
};
49-
}
509
template <typename DSM> class csr_row_segment_iterator;
5110

5211
template <typename DSM> class csr_row_segment_reference {
@@ -96,10 +55,6 @@ template <typename DSM> class csr_row_segment_iterator {
9655
dsm_ = dsm;
9756
segment_index_ = segment_index;
9857
index_ = index;
99-
if (dsm_->vals_backend_.getrank() == segment_index_) {
100-
elem_view_ = get_elem_view(dsm_, segment_index);
101-
base_iter = elem_view_.begin();
102-
}
10358
}
10459

10560
auto operator<=>(const csr_row_segment_iterator &other) const noexcept {
@@ -273,35 +228,13 @@ template <typename DSM> class csr_row_segment_iterator {
273228
auto local() const {
274229
const auto my_process_segment_index = dsm_->vals_backend_.getrank();
275230
assert(my_process_segment_index == segment_index_);
276-
auto [a, b] = *base_iter;
277-
auto [c, d] = a;
278-
fmt::print("aqwsedrftgyhuji {} {} {}\n", b, c, d);
279-
return base_iter;
231+
return dsm_->local_view.begin();
280232
}
281233

282234
private:
283235

284-
static auto get_elem_view(DSM *dsm, std::size_t segment_index) {
285-
std::size_t offset = dsm->segment_size_ * segment_index;
286-
auto row_size = dsm->segment_size_;
287-
auto vals_size = dsm->vals_size_;
288-
auto local_vals = dsm->vals_data_;
289-
auto local_vals_range = rng::subrange(local_vals, local_vals + vals_size);
290-
auto local_cols = dsm->cols_data_;
291-
auto local_cols_range = rng::subrange(local_cols, local_cols + vals_size);
292-
auto local_rows = dsm->rows_data_->segments()[segment_index].begin().local();
293-
auto zipped_results = rng::views::zip(local_vals_range, local_cols_range);
294-
auto enumerated_zipped = rng::views::enumerate(zipped_results);
295-
auto transformer = __detail::transform_fn_1<index_type, value_type>(offset, row_size, local_rows);
296-
return rng::views::transform(enumerated_zipped, transformer);
297-
}
298-
299236
// all fields need to be initialized by default ctor so every default
300237
// constructed iter is equal to any other default constructed iter
301-
using view_type = decltype(get_elem_view(std::declval<DSM*>(), 0));
302-
using iter_type = rng::iterator_t<view_type>;
303-
view_type elem_view_;
304-
iter_type base_iter;
305238
DSM *dsm_ = nullptr;
306239
std::size_t segment_index_ = 0;
307240
std::size_t index_ = 0;

0 commit comments

Comments
 (0)