Skip to content
This repository was archived by the owner on Sep 22, 2025. It is now read-only.

Commit b7704ea

Browse files
author
Mikolaj Komar
committed
Add local to csr_eq_segment
1 parent 55185dc commit b7704ea

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

include/dr/mp/algorithms/reduce.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ inline auto dpl_reduce(rng::forward_range auto &&r, auto &&binary_op) {
3535
sycl::known_identity_v<Fn, T>, binary_op);
3636
} else {
3737
dr::drlog.debug(" peel 1st value\n");
38+
auto base = *rng::begin(r);
3839
return std::reduce(dpl_policy(),
3940
dr::__detail::direct_iterator(rng::begin(r) + 1),
4041
dr::__detail::direct_iterator(rng::end(r)),
41-
sycl_get(*rng::begin(r)), binary_op);
42+
sycl_get(base), binary_op);
4243
}
4344
}
4445
#else

include/dr/mp/containers/matrix_formats/csr_eq_segment.hpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,36 @@ template <typename DSM> class csr_eq_segment_iterator {
235235
return dr::__detail::drop_segments(dsm_->segments(), segment_index_,
236236
index_);
237237
}
238+
239+
auto local() const {
240+
const auto my_process_segment_index = dsm_->rows_backend_.getrank();
241+
242+
assert(my_process_segment_index == segment_index_);
243+
// auto offset = dsm_->row_offsets_[segment_index_];
244+
// auto row_size = dsm_->row_size_;
245+
auto segment_size = dsm_->vals_data_->segment_size();
246+
auto local_vals = dsm_->vals_data_->segments()[segment_index_].begin().local();
247+
auto local_vals_range = rng::subrange(local_vals, local_vals + segment_size);
248+
auto local_cols = dsm_->cols_data_->segments()[segment_index_].begin().local();
249+
auto local_cols_range = rng::subrange(local_cols, local_cols + segment_size);
250+
// auto local_rows = dsm_->rows_data_;
251+
auto zipped_results = rng::views::zip(local_vals_range, local_cols_range);
252+
auto enumerated_zipped = rng::views::enumerate(zipped_results);
253+
auto transformer = [&](auto entry) {
254+
auto [index, pair] = entry;
255+
auto [val, column] = pair;
256+
auto row = 0; //TODO fix calculating row - it results in segfault
257+
// auto row = rng::distance(
258+
// local_rows,
259+
// std::upper_bound(local_rows, local_rows + row_size, offset + index) -
260+
// 1);
261+
dr::index<index_type> index_obj(row, column);
262+
value_type entry_obj(index_obj, val);
263+
return entry_obj;
264+
};
265+
auto transformed_res = rng::transform_view(enumerated_zipped, transformer);
266+
return transformed_res.begin();
267+
}
238268

239269
private:
240270
// all fields need to be initialized by default ctor so every default

0 commit comments

Comments
 (0)