Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 55 additions & 18 deletions csrc/host_ir/pass/stream_parallel_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,20 +423,54 @@ std::list<Expr*> processForLoopBodies(
P2PCommunicationType::RECV,
slicing_output->out(),
recv_peer,
CommunicatorBackend::kNccl);
communicator_backend);
auto send = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::SEND,
slicing_input->out(),
tensor_index,
CommunicatorBackend::kNccl);
auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
auto wait = IrBuilder::create<hir::Wait>(end_coalescing);
if_sending_to_self->elseBody().pushBack(start_coalescing);
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(end_coalescing);
if_sending_to_self->elseBody().pushBack(wait);
communicator_backend);
if (communicator_backend == CommunicatorBackend::kNccl) {
auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
auto wait = IrBuilder::create<hir::Wait>(end_coalescing);

if_sending_to_self->elseBody().pushBack(start_coalescing);
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(end_coalescing);
if_sending_to_self->elseBody().pushBack(wait);
} else if (communicator_backend == CommunicatorBackend::kCuda) {
auto share_mem_handles = IrBuilder::create<hir::ShareMemHandles>(
std::vector<P2PCommunication*>({recv, send}));
auto wait_send = IrBuilder::create<hir::Wait>(send);
auto wait_recv = IrBuilder::create<hir::Wait>(recv);

if_sending_to_self->elseBody().pushBack(share_mem_handles);
switch (getP2pProtocol()) {
case P2pProtocol::Get: {
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(recv);
break;
}
case P2pProtocol::Put: {
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
break;
}
}
if_sending_to_self->elseBody().pushBack(wait_recv);
// Defer the wait on send to the loop epilogue under the same
// predicate
auto* deferred_wait_if = IrBuilder::create<kir::IfThenElse>(
if_sending_to_self->input(0)->as<kir::Predicate>());
deferred_wait_if->elseBody().pushBack(wait_send);
new_loop_body_epilogue.push_back(deferred_wait_if);
} else {
NVF_THROW(
"Unsupported communicator backend for lowering stream parallel "
"type into p2p: ",
communicator_backend);
}
new_loop_body.push_back(slicing_input);
new_loop_body.push_back(slicing_output);
new_loop_body.push_back(if_sending_to_self);
Expand Down Expand Up @@ -533,14 +567,17 @@ std::list<Expr*> processForLoopBodies(
auto wait_recv = IrBuilder::create<hir::Wait>(recv);

if_sending_to_self->elseBody().pushBack(share_mem_handles);
if (getP2pProtocol() == P2pProtocol::Put) {
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
} else if (getP2pProtocol() == P2pProtocol::Get) {
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(recv);
} else {
NVF_ERROR("Invalid P2P protocol: ", getP2pProtocol());
switch (getP2pProtocol()) {
case P2pProtocol::Get: {
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(recv);
break;
}
case P2pProtocol::Put: {
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
break;
}
}
if_sending_to_self->elseBody().pushBack(wait_recv);
// Defer the wait on send to the loop epilogue under the same
Expand Down
16 changes: 14 additions & 2 deletions tests/cpp/test_multidevice_stream_parallel_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,11 @@ TEST_F(MultiDeviceStreamParallelTypeTest, AG_matmul_P2p) {
EXPECT_TRUE(at::allclose(t2_ref, t2, 1e-2, 1e-2));
}

TEST_F(MultiDeviceStreamParallelTypeTest, ReduceScatterP2p) {
class RSMatmulTest : public MultiDeviceStreamParallelTypeTest,
public testing::WithParamInterface<CommunicatorBackend> {};

TEST_P(RSMatmulTest, ReduceScatterP2p) {
CommunicatorBackend communicator_backend = GetParam();
constexpr int64_t M = 32;
constexpr int64_t K = 8;
constexpr int64_t N = 2;
Expand Down Expand Up @@ -538,7 +542,9 @@ TEST_F(MultiDeviceStreamParallelTypeTest, ReduceScatterP2p) {
tv2_unreduced->axis(1)->parallelize(ParallelType::DIDx);
tv2->axis(1)->parallelize(ParallelType::DIDx);

MultiDeviceExecutor executor(std::move(fusion), *communicator_);
MultiDeviceExecutorParams params;
params.lower.communicator_backend = communicator_backend;
MultiDeviceExecutor executor(std::move(fusion), *communicator_, params);

auto tensor_options =
at::TensorOptions().dtype(at::kFloat).device(communicator_->device());
Expand All @@ -558,4 +564,10 @@ TEST_F(MultiDeviceStreamParallelTypeTest, ReduceScatterP2p) {
<< "Output: " << t2 << " Expected: " << t2_ref;
}

INSTANTIATE_TEST_SUITE_P(
,
RSMatmulTest,
testing::Values(CommunicatorBackend::kCuda, CommunicatorBackend::kNccl),
testing::PrintToStringParamName());

} // namespace nvfuser
Loading