Skip to content

Commit 1bbd375

Browse files
authored
Add cuda backend support for MM+RS stream lowering (#5761)
1 parent 06ed9cb commit 1bbd375

File tree

2 files changed

+69
-20
lines changed

2 files changed

+69
-20
lines changed

csrc/host_ir/pass/stream_parallel_type.cpp

Lines changed: 55 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -423,20 +423,54 @@ std::list<Expr*> processForLoopBodies(
423423
P2PCommunicationType::RECV,
424424
slicing_output->out(),
425425
recv_peer,
426-
CommunicatorBackend::kNccl);
426+
communicator_backend);
427427
auto send = IrBuilder::create<P2PCommunication>(
428428
P2PCommunicationType::SEND,
429429
slicing_input->out(),
430430
tensor_index,
431-
CommunicatorBackend::kNccl);
432-
auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
433-
auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
434-
auto wait = IrBuilder::create<hir::Wait>(end_coalescing);
435-
if_sending_to_self->elseBody().pushBack(start_coalescing);
436-
if_sending_to_self->elseBody().pushBack(recv);
437-
if_sending_to_self->elseBody().pushBack(send);
438-
if_sending_to_self->elseBody().pushBack(end_coalescing);
439-
if_sending_to_self->elseBody().pushBack(wait);
431+
communicator_backend);
432+
if (communicator_backend == CommunicatorBackend::kNccl) {
433+
auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
434+
auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
435+
auto wait = IrBuilder::create<hir::Wait>(end_coalescing);
436+
437+
if_sending_to_self->elseBody().pushBack(start_coalescing);
438+
if_sending_to_self->elseBody().pushBack(recv);
439+
if_sending_to_self->elseBody().pushBack(send);
440+
if_sending_to_self->elseBody().pushBack(end_coalescing);
441+
if_sending_to_self->elseBody().pushBack(wait);
442+
} else if (communicator_backend == CommunicatorBackend::kCuda) {
443+
auto share_mem_handles = IrBuilder::create<hir::ShareMemHandles>(
444+
std::vector<P2PCommunication*>({recv, send}));
445+
auto wait_send = IrBuilder::create<hir::Wait>(send);
446+
auto wait_recv = IrBuilder::create<hir::Wait>(recv);
447+
448+
if_sending_to_self->elseBody().pushBack(share_mem_handles);
449+
switch (getP2pProtocol()) {
450+
case P2pProtocol::Get: {
451+
if_sending_to_self->elseBody().pushBack(send);
452+
if_sending_to_self->elseBody().pushBack(recv);
453+
break;
454+
}
455+
case P2pProtocol::Put: {
456+
if_sending_to_self->elseBody().pushBack(recv);
457+
if_sending_to_self->elseBody().pushBack(send);
458+
break;
459+
}
460+
}
461+
if_sending_to_self->elseBody().pushBack(wait_recv);
462+
// Defer the wait on send to the loop epilogue under the same
463+
// predicate
464+
auto* deferred_wait_if = IrBuilder::create<kir::IfThenElse>(
465+
if_sending_to_self->input(0)->as<kir::Predicate>());
466+
deferred_wait_if->elseBody().pushBack(wait_send);
467+
new_loop_body_epilogue.push_back(deferred_wait_if);
468+
} else {
469+
NVF_THROW(
470+
"Unsupported communicator backend for lowering stream parallel "
471+
"type into p2p: ",
472+
communicator_backend);
473+
}
440474
new_loop_body.push_back(slicing_input);
441475
new_loop_body.push_back(slicing_output);
442476
new_loop_body.push_back(if_sending_to_self);
@@ -533,14 +567,17 @@ std::list<Expr*> processForLoopBodies(
533567
auto wait_recv = IrBuilder::create<hir::Wait>(recv);
534568

535569
if_sending_to_self->elseBody().pushBack(share_mem_handles);
536-
if (getP2pProtocol() == P2pProtocol::Put) {
537-
if_sending_to_self->elseBody().pushBack(recv);
538-
if_sending_to_self->elseBody().pushBack(send);
539-
} else if (getP2pProtocol() == P2pProtocol::Get) {
540-
if_sending_to_self->elseBody().pushBack(send);
541-
if_sending_to_self->elseBody().pushBack(recv);
542-
} else {
543-
NVF_ERROR("Invalid P2P protocol: ", getP2pProtocol());
570+
switch (getP2pProtocol()) {
571+
case P2pProtocol::Get: {
572+
if_sending_to_self->elseBody().pushBack(send);
573+
if_sending_to_self->elseBody().pushBack(recv);
574+
break;
575+
}
576+
case P2pProtocol::Put: {
577+
if_sending_to_self->elseBody().pushBack(recv);
578+
if_sending_to_self->elseBody().pushBack(send);
579+
break;
580+
}
544581
}
545582
if_sending_to_self->elseBody().pushBack(wait_recv);
546583
// Defer the wait on send to the loop epilogue under the same

tests/cpp/test_multidevice_stream_parallel_type.cpp

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -485,7 +485,11 @@ TEST_F(MultiDeviceStreamParallelTypeTest, AG_matmul_P2p) {
485485
EXPECT_TRUE(at::allclose(t2_ref, t2, 1e-2, 1e-2));
486486
}
487487

488-
TEST_F(MultiDeviceStreamParallelTypeTest, ReduceScatterP2p) {
488+
class RSMatmulTest : public MultiDeviceStreamParallelTypeTest,
489+
public testing::WithParamInterface<CommunicatorBackend> {};
490+
491+
TEST_P(RSMatmulTest, ReduceScatterP2p) {
492+
CommunicatorBackend communicator_backend = GetParam();
489493
constexpr int64_t M = 32;
490494
constexpr int64_t K = 8;
491495
constexpr int64_t N = 2;
@@ -538,7 +542,9 @@ TEST_F(MultiDeviceStreamParallelTypeTest, ReduceScatterP2p) {
538542
tv2_unreduced->axis(1)->parallelize(ParallelType::DIDx);
539543
tv2->axis(1)->parallelize(ParallelType::DIDx);
540544

541-
MultiDeviceExecutor executor(std::move(fusion), *communicator_);
545+
MultiDeviceExecutorParams params;
546+
params.lower.communicator_backend = communicator_backend;
547+
MultiDeviceExecutor executor(std::move(fusion), *communicator_, params);
542548

543549
auto tensor_options =
544550
at::TensorOptions().dtype(at::kFloat).device(communicator_->device());
@@ -558,4 +564,10 @@ TEST_F(MultiDeviceStreamParallelTypeTest, ReduceScatterP2p) {
558564
<< "Output: " << t2 << " Expected: " << t2_ref;
559565
}
560566

567+
INSTANTIATE_TEST_SUITE_P(
568+
,
569+
RSMatmulTest,
570+
testing::Values(CommunicatorBackend::kCuda, CommunicatorBackend::kNccl),
571+
testing::PrintToStringParamName());
572+
561573
} // namespace nvfuser

0 commit comments

Comments
 (0)