@@ -29,29 +29,51 @@ struct ExtraBuffers : protected Matrix<T, D> {
2929 pika::execution::thread_priority::high)));
3030 }
3131
32+ auto read_sender (SizeType index) {
33+ return Matrix<T, D>::read_sender (internalIndex (index));
34+ }
35+
3236 auto readwrite_sender (SizeType index) {
33- index %= nbuffers_;
34- return Matrix<T, D>::readwrite_sender (LocalTileIndex{index, 0 });
37+ return Matrix<T, D>::readwrite_sender (internalIndex (index));
3538 }
3639
3740 template <class TileSender >
3841 [[nodiscard]] auto reduce (TileSender tile) {
3942 namespace ex = pika::execution::experimental;
4043
41- std::vector<pika::future<matrix::Tile<T, D>>> buffers;
42- for (const auto & ij : common::iterate_range2d (this ->distribution ().localNrTiles ()))
43- buffers.emplace_back (Matrix<T, D>::operator ()(ij));
44- auto all_buffers = ex::when_all_vector (std::move (buffers));
45-
46- return ex::when_all (std::move (tile), std::move (all_buffers)) |
47- ex::then ([](const matrix::Tile<T, D>& tile, const std::vector<matrix::Tile<T, D>>& buffers) {
48- tile::internal::set0 (tile);
49- for (auto & buffer : buffers)
50- dlaf::tile::internal::add (T (1 ), buffer, tile);
51- });
44+ std::vector<ex::any_sender<pika::shared_future<matrix::Tile<const T, D>>>> buffers;
45+ for (SizeType index = 0 ; index < nbuffers_; ++index)
46+ buffers.emplace_back (read_sender (index));
47+
48+ return ex::when_all (std::move (tile), ex::when_all_vector (std::move (buffers))) |
49+ dlaf::internal::transform (dlaf::internal::Policy<DefaultBackend_v<D>>(),
50+ [](const matrix::Tile<T, D>& tile,
51+ const std::vector<pika::shared_future<matrix::Tile<const T, D>>>&
52+ buffers,
53+ auto &&... ts) {
54+ for (const auto & buffer : buffers) {
55+ if constexpr (D == Device::CPU) {
56+ static_assert (sizeof ...(ts) == 0 ,
57+ " Parameter pack should be empty for MC." );
58+ dlaf::tile::internal::add (T (1 ), buffer.get (), tile);
59+ }
60+ #ifdef DLAF_WITH_GPU
61+ else if constexpr (D == Device::GPU) {
62+ dlaf::tile::internal::add (T (1 ), buffer.get (), tile, ts...);
63+ }
64+ #endif
65+ else {
66+ DLAF_STATIC_UNIMPLEMENTED (T);
67+ }
68+ }
69+ });
5270 }
5371
5472protected:
73+ LocalTileIndex internalIndex (SizeType index) const noexcept {
74+ return LocalTileIndex{index % nbuffers_, 0 };
75+ }
76+
5577 SizeType nbuffers_;
5678};
5779}
0 commit comments