Skip to content
This repository was archived by the owner on Sep 22, 2025. It is now read-only.

Commit e42cfa2

Browse files
author
Mikolaj Komar
committed
Fix issue when distributed vector is too small
1 parent dd1d6ed commit e42cfa2

File tree

4 files changed

+27
-11
lines changed

4 files changed

+27
-11
lines changed

include/dr/mp/containers/matrix_formats/csr_eq_distribution.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,6 @@ class csr_eq_distribution {
301301
segments_.emplace_back(this, segment_index++,
302302
std::min(segment_size_, nnz_ - i), segment_size_);
303303
}
304-
fence();
305304
auto local_rows = rows_data_;
306305
auto real_val_size = std::min(vals_data_->segment_size(),
307306
nnz_ - vals_data_->segment_size() * rank);
@@ -317,10 +316,15 @@ class csr_eq_distribution {
317316
view_helper_const[0] = my_tuple;
318317
}
319318

320-
auto local_cols = cols_data_->segments()[rank].begin().local();
321-
auto local_vals = vals_data_->segments()[rank].begin().local();
322-
local_view = std::make_shared<view_type>(get_elem_view(
323-
real_val_size, view_helper_const, local_cols, local_vals, rank));
319+
auto local_cols = static_cast<I *>(nullptr);
320+
auto local_vals = static_cast<T *>(nullptr);
321+
if (cols_data_->segments().size() > rank) {
322+
local_cols = cols_data_->segments()[rank].begin().local();
323+
local_vals = vals_data_->segments()[rank].begin().local();
324+
local_view = std::make_shared<view_type>(get_elem_view(
325+
real_val_size, view_helper_const, local_cols, local_vals, rank));
326+
}
327+
fence();
324328
}
325329

326330
static auto get_elem_view(std::size_t vals_size, view_tuple *helper_tuple,
@@ -360,7 +364,7 @@ class csr_eq_distribution {
360364

361365
dr::mp::__detail::allocator<view_tuple> tuple_alloc;
362366
view_tuple *view_helper_const;
363-
std::shared_ptr<view_type> local_view;
367+
std::shared_ptr<view_type> local_view = nullptr;
364368

365369
std::size_t segment_size_ = 0;
366370
std::size_t row_size_ = 0;

include/dr/mp/containers/matrix_formats/csr_eq_segment.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,9 @@ template <typename DSM> class csr_eq_segment_iterator {
239239
auto local() const {
240240
const auto my_process_segment_index = dsm_->rows_backend_.getrank();
241241
assert(my_process_segment_index == segment_index_);
242+
if (dsm_->local_view == nullptr) {
243+
return nullptr;
244+
}
242245
return dsm_->local_view->begin();
243246
}
244247

include/dr/mp/containers/matrix_formats/csr_row_distribution.hpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,11 @@ class csr_row_distribution {
289289
this, segment_index++, val_sizes_[i],
290290
std::max(val_sizes_[i], static_cast<std::size_t>(1)));
291291
}
292-
fence();
293-
auto local_rows = rows_data_->segments()[rank].begin().local();
292+
293+
auto local_rows = static_cast<I *>(nullptr);
294+
if (rows_data_->segments().size() > rank) {
295+
local_rows = rows_data_->segments()[rank].begin().local();
296+
}
294297
auto offset = val_offsets_[rank];
295298
auto real_row_size =
296299
std::min(rows_data_->segment_size(),
@@ -307,8 +310,11 @@ class csr_row_distribution {
307310
view_helper_const[0] = my_tuple;
308311
}
309312

310-
local_view = std::make_shared<view_type>(get_elem_view(
311-
vals_size_, view_helper_const, cols_data_, vals_data_, rank));
313+
if (rows_data_->segments().size() > rank) {
314+
local_view = std::make_shared<view_type>(get_elem_view(
315+
vals_size_, view_helper_const, cols_data_, vals_data_, rank));
316+
}
317+
fence();
312318
}
313319

314320
static auto get_elem_view(std::size_t vals_size, view_tuple *helper_tuple,
@@ -365,6 +371,6 @@ class csr_row_distribution {
365371
dr::index<size_t> shape_;
366372
std::size_t nnz_;
367373
std::vector<segment_type> segments_;
368-
std::shared_ptr<distributed_vector<I>> rows_data_;
374+
std::shared_ptr<distributed_vector<I>> rows_data_ = nullptr;
369375
};
370376
} // namespace dr::mp

include/dr/mp/containers/matrix_formats/csr_row_segment.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ template <typename DSM> class csr_row_segment_iterator {
227227
auto local() const {
228228
const auto my_process_segment_index = dsm_->vals_backend_.getrank();
229229
assert(my_process_segment_index == segment_index_);
230+
if (dsm_->local_view == nullptr) {
231+
return nullptr;
232+
}
230233
return dsm_->local_view->begin();
231234
}
232235

0 commit comments

Comments
 (0)