Skip to content
This repository was archived by the owner on Sep 22, 2025. It is now read-only.

Commit 3503271

Browse files
author
Mikolaj Komar
committed
Make local work with shared memory
1 parent 8e7f1fe commit 3503271

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

include/dr/mp/containers/matrix_formats/csr_row_distribution.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
namespace dr::mp {
1212
template <typename T, typename I, class BackendT = MpiBackend>
1313
class csr_row_distribution {
14-
using view_tuple = std::tuple<std::size_t, std::size_t, I*>;
14+
using view_tuple = std::tuple<std::size_t, std::size_t, std::size_t, I*>;
1515
public:
1616
using value_type = dr::matrix_entry<T, I>;
1717
using segment_type = csr_row_segment<csr_row_distribution>;
@@ -290,7 +290,9 @@ class csr_row_distribution {
290290
}
291291
fence();
292292
auto local_rows = rows_data_->segments()[rank].begin().local();
293-
auto my_tuple = std::make_tuple(rows_data_->segment_size(), segment_size_ * rank, local_rows);
293+
auto offset = val_offsets_[rank];
294+
auto real_row_size = std::min(rows_data_->segment_size(), shape_.first - rows_data_->segment_size() * rank);
295+
auto my_tuple = std::make_tuple(real_row_size, segment_size_ * rank, offset, local_rows);
294296
view_helper_const = alloc.allocate(1);
295297

296298
view_helper_const[0] = my_tuple;
@@ -316,16 +318,13 @@ class csr_row_distribution {
316318

317319
auto transformer = [=](auto x) {
318320
auto [entry, tuple] = x;
319-
auto [row_size, offset, local_rows] = tuple;
320-
assert(offset == 0);
321-
assert(local_rows[0] == 0);
322-
assert(row_size == 10);
321+
auto [row_size, row_offset, offset, local_rows] = tuple;
323322
auto [index, pair] = entry;
324323
auto [val, column] = pair;
325324
auto row = rng::distance(
326325
local_rows,
327326
std::upper_bound(local_rows, local_rows + row_size, offset + index) -
328-
1);
327+
1) + row_offset;
329328
dr::index<index_type> index_obj(row, column);
330329
value_type entry_obj(index_obj, val);
331330
return entry_obj;

0 commit comments

Comments
 (0)