
Commit 0201ccc

fix p2p comm memory release logic (#47497) (#47517)
1 parent 4b3589f

File tree

1 file changed: +12 -12 lines changed


paddle/fluid/distributed/collective/ProcessGroupNCCL.cc

Lines changed: 12 additions & 12 deletions
@@ -448,7 +448,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
 
   platform::CUDADeviceGuard cuda_guard;
 
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       gpuStream_t nccl_stream;
@@ -460,12 +461,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
       } else {
         nccl_stream = places_to_ctx_[key][i]->stream();
       }
-      memory::RecordStream(tensors[i].Holder(), nccl_stream);
+      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
     }
   }
 
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       gpuStream_t nccl_stream;
@@ -477,7 +477,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
       } else {
         nccl_stream = places_to_ctx_[key][i]->stream();
       }
-      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
+      memory::RecordStream(tensors[i].Holder(), nccl_stream);
     }
   }
 
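The three hunks above make one swap in the first PointToPoint overload: the loop that launches the NCCL operation (`fn(...)`, under platform::NCCLGroupGuard) now runs first, and the memory::RecordStream loop (guarded by FLAGS_use_stream_safe_cuda_allocator) follows it. A minimal, self-contained sketch of the resulting order, using hypothetical stand-ins (Allocation, RecordStream, LaunchCommOnStream) rather than Paddle's real types:

// Sketch only: Allocation stands in for tensors[i].Holder(), RecordStream
// for memory::RecordStream, and LaunchCommOnStream for the NCCL lambda fn.
#include <cstdio>

using Stream = int;  // stand-in for gpuStream_t

struct Allocation {
  Stream last_used_stream = -1;  // a stream-safe allocator defers the free
                                 // until this stream's pending work is done
};

// Stand-in for memory::RecordStream: tag the buffer with the stream that
// now owns outstanding work on it.
void RecordStream(Allocation& alloc, Stream s) { alloc.last_used_stream = s; }

// Stand-in for fn(tensor, comm, stream, rank): enqueue the async send/recv.
void LaunchCommOnStream(Allocation& alloc, Stream s) {
  std::printf("comm enqueued on stream %d\n", s);
}

int main() {
  Allocation tensor_holder;
  Stream comm_stream = 7;

  // Order after this commit: enqueue the communication first...
  LaunchCommOnStream(tensor_holder, comm_stream);
  // ...then record the stream, so the release bookkeeping happens after the
  // work that uses the buffer has been queued. The old code recorded first
  // and launched second.
  RecordStream(tensor_holder, comm_stream);
  return 0;
}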
@@ -516,20 +516,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   // construct uninitialize guard for device
   platform::CUDADeviceGuard cuda_guard;
 
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      memory::RecordStream(tensors[i].Holder(),
-                           places_to_ctx_[key][i]->stream());
+      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
+      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
     }
   }
 
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
-      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
+      memory::RecordStream(tensors[i].Holder(),
+                           places_to_ctx_[key][i]->stream());
     }
   }
 
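The second PointToPoint overload (lines 516-535) gets the same reordering. With a stream-safe allocator, RecordStream is what defers a buffer's release until the recorded stream has drained, so the apparent intent of the fix, consistent with the commit title, is that this release bookkeeping should run only after the send/recv that actually uses the buffer has been enqueued under the group guard.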