Skip to content

Commit d9da590

Browse files
committed
Fix GPU support and do not implicitly set0 the result tile in the reduce
(not yet implemented as a continuation)
1 parent 1c06b55 commit d9da590

File tree

2 files changed: +36 / -15 lines changed

include/dlaf/eigensolver/reduction_to_band/impl.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -462,8 +462,6 @@ void gemmComputeW2(matrix::Matrix<T, D>& w2, matrix::Panel<Coord::Col, const T,
462462
//// Not every rank in the column necessarily holds a tile in the panel Ai, but all ranks in
463463
//// the column are going to participate in the reduce. For them, it is important to set the
464464
//// partial result W2 to zero.
465-
// ex::start_detached(w2.readwrite_sender(LocalTileIndex(0, 0)) |
466-
// tile::set0(dlaf::internal::Policy<B>(thread_priority::high)));
467465

468466
using namespace blas;
469467
// GEMM W2 = W* . X
@@ -474,6 +472,7 @@ void gemmComputeW2(matrix::Matrix<T, D>& w2, matrix::Panel<Coord::Col, const T,
474472
tile::gemm(dlaf::internal::Policy<B>(thread_priority::high)));
475473
}
476474

475+
ex::start_detached(tile::set0(dlaf::internal::Policy<B>(), w2.readwrite_sender(LocalTileIndex(0, 0))));
477476
ex::start_detached(buffers.reduce(w2.readwrite_sender(LocalTileIndex(0, 0))));
478477
}
479478

include/dlaf/matrix/extra_buffers.h

Lines changed: 35 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -29,29 +29,51 @@ struct ExtraBuffers : protected Matrix<T, D> {
2929
pika::execution::thread_priority::high)));
3030
}
3131

32+
auto read_sender(SizeType index) {
33+
return Matrix<T, D>::read_sender(internalIndex(index));
34+
}
35+
3236
auto readwrite_sender(SizeType index) {
33-
index %= nbuffers_;
34-
return Matrix<T, D>::readwrite_sender(LocalTileIndex{index, 0});
37+
return Matrix<T, D>::readwrite_sender(internalIndex(index));
3538
}
3639

3740
template <class TileSender>
3841
[[nodiscard]] auto reduce(TileSender tile) {
3942
namespace ex = pika::execution::experimental;
4043

41-
std::vector<pika::future<matrix::Tile<T, D>>> buffers;
42-
for (const auto& ij : common::iterate_range2d(this->distribution().localNrTiles()))
43-
buffers.emplace_back(Matrix<T, D>::operator()(ij));
44-
auto all_buffers = ex::when_all_vector(std::move(buffers));
45-
46-
return ex::when_all(std::move(tile), std::move(all_buffers)) |
47-
ex::then([](const matrix::Tile<T, D>& tile, const std::vector<matrix::Tile<T, D>>& buffers) {
48-
tile::internal::set0(tile);
49-
for (auto& buffer : buffers)
50-
dlaf::tile::internal::add(T(1), buffer, tile);
51-
});
44+
std::vector<ex::any_sender<pika::shared_future<matrix::Tile<const T, D>>>> buffers;
45+
for (SizeType index = 0; index < nbuffers_; ++index)
46+
buffers.emplace_back(read_sender(index));
47+
48+
return ex::when_all(std::move(tile), ex::when_all_vector(std::move(buffers))) |
49+
dlaf::internal::transform(dlaf::internal::Policy<DefaultBackend_v<D>>(),
50+
[](const matrix::Tile<T, D>& tile,
51+
const std::vector<pika::shared_future<matrix::Tile<const T, D>>>&
52+
buffers,
53+
auto&&... ts) {
54+
for (const auto& buffer : buffers) {
55+
if constexpr (D == Device::CPU) {
56+
static_assert(sizeof...(ts) == 0,
57+
"Parameter pack should be empty for MC.");
58+
dlaf::tile::internal::add(T(1), buffer.get(), tile);
59+
}
60+
#ifdef DLAF_WITH_GPU
61+
else if constexpr (D == Device::GPU) {
62+
dlaf::tile::internal::add(T(1), buffer.get(), tile, ts...);
63+
}
64+
#endif
65+
else {
66+
DLAF_STATIC_UNIMPLEMENTED(T);
67+
}
68+
}
69+
});
5270
}
5371

5472
protected:
73+
LocalTileIndex internalIndex(SizeType index) const noexcept {
74+
return LocalTileIndex{index % nbuffers_, 0};
75+
}
76+
5577
SizeType nbuffers_;
5678
};
5779
}

0 commit comments

Comments
 (0)