@@ -423,20 +423,54 @@ std::list<Expr*> processForLoopBodies(
423423 P2PCommunicationType::RECV,
424424 slicing_output->out (),
425425 recv_peer,
426- CommunicatorBackend:: kNccl );
426+ communicator_backend );
427427 auto send = IrBuilder::create<P2PCommunication>(
428428 P2PCommunicationType::SEND,
429429 slicing_input->out (),
430430 tensor_index,
431- CommunicatorBackend::kNccl );
432- auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
433- auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
434- auto wait = IrBuilder::create<hir::Wait>(end_coalescing);
435- if_sending_to_self->elseBody ().pushBack (start_coalescing);
436- if_sending_to_self->elseBody ().pushBack (recv);
437- if_sending_to_self->elseBody ().pushBack (send);
438- if_sending_to_self->elseBody ().pushBack (end_coalescing);
439- if_sending_to_self->elseBody ().pushBack (wait);
431+ communicator_backend);
432+ if (communicator_backend == CommunicatorBackend::kNccl ) {
433+ auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
434+ auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
435+ auto wait = IrBuilder::create<hir::Wait>(end_coalescing);
436+
437+ if_sending_to_self->elseBody ().pushBack (start_coalescing);
438+ if_sending_to_self->elseBody ().pushBack (recv);
439+ if_sending_to_self->elseBody ().pushBack (send);
440+ if_sending_to_self->elseBody ().pushBack (end_coalescing);
441+ if_sending_to_self->elseBody ().pushBack (wait);
442+ } else if (communicator_backend == CommunicatorBackend::kCuda ) {
443+ auto share_mem_handles = IrBuilder::create<hir::ShareMemHandles>(
444+ std::vector<P2PCommunication*>({recv, send}));
445+ auto wait_send = IrBuilder::create<hir::Wait>(send);
446+ auto wait_recv = IrBuilder::create<hir::Wait>(recv);
447+
448+ if_sending_to_self->elseBody ().pushBack (share_mem_handles);
449+ switch (getP2pProtocol ()) {
450+ case P2pProtocol::Get: {
451+ if_sending_to_self->elseBody ().pushBack (send);
452+ if_sending_to_self->elseBody ().pushBack (recv);
453+ break ;
454+ }
455+ case P2pProtocol::Put: {
456+ if_sending_to_self->elseBody ().pushBack (recv);
457+ if_sending_to_self->elseBody ().pushBack (send);
458+ break ;
459+ }
460+ }
461+ if_sending_to_self->elseBody ().pushBack (wait_recv);
462+ // Defer the wait on send to the loop epilogue under the same
463+ // predicate
464+ auto * deferred_wait_if = IrBuilder::create<kir::IfThenElse>(
465+ if_sending_to_self->input (0 )->as <kir::Predicate>());
466+ deferred_wait_if->elseBody ().pushBack (wait_send);
467+ new_loop_body_epilogue.push_back (deferred_wait_if);
468+ } else {
469+ NVF_THROW (
470+ " Unsupported communicator backend for lowering stream parallel "
471+ " type into p2p: " ,
472+ communicator_backend);
473+ }
440474 new_loop_body.push_back (slicing_input);
441475 new_loop_body.push_back (slicing_output);
442476 new_loop_body.push_back (if_sending_to_self);
@@ -533,14 +567,17 @@ std::list<Expr*> processForLoopBodies(
533567 auto wait_recv = IrBuilder::create<hir::Wait>(recv);
534568
535569 if_sending_to_self->elseBody ().pushBack (share_mem_handles);
536- if (getP2pProtocol () == P2pProtocol::Put) {
537- if_sending_to_self->elseBody ().pushBack (recv);
538- if_sending_to_self->elseBody ().pushBack (send);
539- } else if (getP2pProtocol () == P2pProtocol::Get) {
540- if_sending_to_self->elseBody ().pushBack (send);
541- if_sending_to_self->elseBody ().pushBack (recv);
542- } else {
543- NVF_ERROR (" Invalid P2P protocol: " , getP2pProtocol ());
570+ switch (getP2pProtocol ()) {
571+ case P2pProtocol::Get: {
572+ if_sending_to_self->elseBody ().pushBack (send);
573+ if_sending_to_self->elseBody ().pushBack (recv);
574+ break ;
575+ }
576+ case P2pProtocol::Put: {
577+ if_sending_to_self->elseBody ().pushBack (recv);
578+ if_sending_to_self->elseBody ().pushBack (send);
579+ break ;
580+ }
544581 }
545582 if_sending_to_self->elseBody ().pushBack (wait_recv);
546583 // Defer the wait on send to the loop epilogue under the same
0 commit comments