|
4 | 4 |
|
5 | 5 | #pragma once |
6 | 6 |
|
| 7 | + |
| 8 | +int some_id_base =0; |
7 | 9 | namespace dr::mp { |
| 10 | +namespace __detail { |
| 11 | + template <typename T, typename V> |
| 12 | + class transform_fn_1 { |
| 13 | + public: |
| 14 | + using value_type = V; |
| 15 | + using index_type = T; |
| 16 | + transform_fn_1(std::size_t offset, std::size_t row_size, T* row_ptr): |
| 17 | + offset_(offset), row_size_(row_size), row_ptr_(row_ptr) { |
| 18 | + assert(offset_ == 0); |
| 19 | + myid = some_id_base++; |
| 20 | + fmt::print("created {}\n", myid); |
| 21 | + } |
8 | 22 |
|
| 23 | + ~transform_fn_1() { |
| 24 | + destroyed = true; |
| 25 | + fmt::print("destroyed {}\n", myid); |
| 26 | + } |
| 27 | + template <typename P> |
| 28 | + auto operator()(P entry) const { |
| 29 | + fmt::print("called {}\n", myid); |
| 30 | + assert(offset_ == 0); |
| 31 | + assert(!destroyed); |
| 32 | + auto [index, pair] = entry; |
| 33 | + auto [val, column] = pair; |
| 34 | + auto row = rng::distance( |
| 35 | + row_ptr_, |
| 36 | + std::upper_bound(row_ptr_, row_ptr_ + row_size_, offset_) - |
| 37 | + 1); |
| 38 | + dr::index<index_type> index_obj(row, column); |
| 39 | + value_type entry_obj(index_obj, val); |
| 40 | + return entry_obj; |
| 41 | + } |
| 42 | + private: |
| 43 | + int myid = 0; |
| 44 | + bool destroyed = false; |
| 45 | + std::size_t offset_; |
| 46 | + std::size_t row_size_; |
| 47 | + T* row_ptr_; |
| 48 | + }; |
| 49 | +} |
9 | 50 | template <typename DSM> class csr_row_segment_iterator; |
10 | 51 |
|
11 | 52 | template <typename DSM> class csr_row_segment_reference { |
@@ -55,6 +96,10 @@ template <typename DSM> class csr_row_segment_iterator { |
55 | 96 | dsm_ = dsm; |
56 | 97 | segment_index_ = segment_index; |
57 | 98 | index_ = index; |
| 99 | + if (dsm_->vals_backend_.getrank() == segment_index_) { |
| 100 | + elem_view_ = get_elem_view(dsm_, segment_index); |
| 101 | + base_iter = elem_view_.begin(); |
| 102 | + } |
58 | 103 | } |
59 | 104 |
|
60 | 105 | auto operator<=>(const csr_row_segment_iterator &other) const noexcept { |
@@ -227,40 +272,36 @@ template <typename DSM> class csr_row_segment_iterator { |
227 | 272 |
|
228 | 273 | auto local() const { |
229 | 274 | const auto my_process_segment_index = dsm_->vals_backend_.getrank(); |
230 | | - |
231 | 275 | assert(my_process_segment_index == segment_index_); |
232 | | - std::size_t offset = dsm_->segment_size_ * segment_index_; |
233 | | - assert(offset == 0); |
234 | | - // auto row_size = dsm_->segment_size_; |
235 | | - auto vals_size = dsm_->vals_size_; |
236 | | - auto local_vals = dsm_->vals_data_; |
| 276 | + auto [a, b] = *base_iter; |
| 277 | + auto [c, d] = a; |
| 278 | + fmt::print("aqwsedrftgyhuji {} {} {}\n", b, c, d); |
| 279 | + return base_iter; |
| 280 | + } |
| 281 | + |
| 282 | +private: |
| 283 | + |
| 284 | + static auto get_elem_view(DSM *dsm, std::size_t segment_index) { |
| 285 | + std::size_t offset = dsm->segment_size_ * segment_index; |
| 286 | + auto row_size = dsm->segment_size_; |
| 287 | + auto vals_size = dsm->vals_size_; |
| 288 | + auto local_vals = dsm->vals_data_; |
237 | 289 | auto local_vals_range = rng::subrange(local_vals, local_vals + vals_size); |
238 | | - auto local_cols = dsm_->cols_data_; |
| 290 | + auto local_cols = dsm->cols_data_; |
239 | 291 | auto local_cols_range = rng::subrange(local_cols, local_cols + vals_size); |
240 | | - // auto local_rows = dsm_->rows_data_->segments()[segment_index_].begin().local(); |
| 292 | + auto local_rows = dsm->rows_data_->segments()[segment_index].begin().local(); |
241 | 293 | auto zipped_results = rng::views::zip(local_vals_range, local_cols_range); |
242 | 294 | auto enumerated_zipped = rng::views::enumerate(zipped_results); |
243 | | - auto transformer = [=](auto entry) { |
244 | | - assert(offset == 0); |
245 | | - auto [index, pair] = entry; |
246 | | - auto [val, column] = pair; |
247 | | - auto row = 0; //TODO fix calculating row - it results in segfault |
248 | | - // problem originates from the fact that variables cannot be caputed properly by value |
249 | | - // auto row = rng::distance( |
250 | | - // local_rows, |
251 | | - // std::upper_bound(local_rows, local_rows + row_size, offset) - |
252 | | - // 1); |
253 | | - dr::index<index_type> index_obj(row, column); |
254 | | - value_type entry_obj(index_obj, val); |
255 | | - return entry_obj; |
256 | | - }; |
257 | | - auto transformed_res = rng::views::transform(enumerated_zipped, transformer); |
258 | | - return transformed_res.begin(); |
| 295 | + auto transformer = __detail::transform_fn_1<index_type, value_type>(offset, row_size, local_rows); |
| 296 | + return rng::views::transform(enumerated_zipped, transformer); |
259 | 297 | } |
260 | 298 |
|
261 | | -private: |
262 | 299 | // all fields need to be initialized by default ctor so every default |
263 | 300 | // constructed iter is equal to any other default constructed iter |
| 301 | + using view_type = decltype(get_elem_view(std::declval<DSM*>(), 0)); |
| 302 | + using iter_type = rng::iterator_t<view_type>; |
| 303 | + view_type elem_view_; |
| 304 | + iter_type base_iter; |
264 | 305 | DSM *dsm_ = nullptr; |
265 | 306 | std::size_t segment_index_ = 0; |
266 | 307 | std::size_t index_ = 0; |
|
0 commit comments