Skip to content

Commit 77cc792

Browse files
authored
Add FlightRecorder support for ProcessGroupXCCL (#1867)
This PR provides initial support for FlightRecorder, which allows debug trace dumps for distributed jobs. Features added: 1. A Heartbeat Monitor thread that regularly checks whether a dump signal has been received via pipe files and, when triggered, writes traces to file (separately for each rank). 2. Logic to record XCCL work events. Compared to NCCL, some features are missing; these could be added in a later PR: 1. No support for event timing/duration. 2. No support for a Watchdog thread, which allows a debug dump on error or timeout along with other additional features (remote error detection, etc.).
1 parent 5ee2a32 commit 77cc792

File tree

6 files changed

+404
-12
lines changed

6 files changed

+404
-12
lines changed

src/BuildOnLinux.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ macro(setup_common_libraries)
1616
if(USE_C10D_XCCL)
1717
target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
1818
target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
19+
target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
1920
endif()
2021
list(APPEND TORCH_XPU_OPS_LIBRARIES torch_xpu_ops)
2122
endmacro()
@@ -125,6 +126,7 @@ else()
125126
if(USE_C10D_XCCL)
126127
target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
127128
target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
129+
target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
128130
endif()
129131

130132
install(TARGETS torch_xpu_ops DESTINATION "${TORCH_INSTALL_LIB_DIR}")

src/xccl/FlightRecorderXCCL.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifdef USE_C10D_XCCL
2+
3+
#include <torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
4+
#include <ATen/xpu/XPUEvent.h>
5+
#include <xccl/ProcessGroupXCCL.hpp>
6+
7+
namespace c10d {
8+
9+
template <>
10+
float getDurationFromEvent<at::xpu::XPUEvent>(
11+
at::xpu::XPUEvent& xcclStartEvent,
12+
at::xpu::XPUEvent& xcclEndEvent) {
13+
TORCH_CHECK(
14+
xcclEndEvent.query(),
15+
"getDuration can only be called after work is succeeded.")
16+
return xcclStartEvent.elapsed_time(xcclEndEvent);
17+
}
18+
19+
template struct FlightRecorder<at::xpu::XPUEvent>;
20+
} // namespace c10d
21+
#endif // USE_C10D_XCCL

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 176 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
#ifdef USE_C10D_XCCL
22

33
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
4+
#include <torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
45
#include <xccl/NanCheck_XPU.hpp>
56
#include <xccl/ProcessGroupXCCL.hpp>
67

78
namespace c10d {
89

10+
using FlightRecorderXCCL = FlightRecorder<at::xpu::XPUEvent>;
11+
912
namespace {
1013

1114
#if defined(CCL_MAJOR_VERSION) && \
@@ -200,6 +203,17 @@ void syncStream(
200203

201204
} // namespace
202205

206+
std::string dump_xccl_trace(
207+
bool includeCollectives,
208+
bool includeStackTraces,
209+
bool onlyActive) {
210+
auto xcclDumpMap = std::unordered_map<
211+
std::string,
212+
std::unordered_map<std::string, std::string>>();
213+
return FlightRecorderXCCL::get()->dump(
214+
xcclDumpMap, includeCollectives, includeStackTraces, onlyActive);
215+
}
216+
203217
constexpr int64_t kSynchronizeBusyWaitMillis = 10;
204218
thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0;
205219

@@ -303,6 +317,10 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
303317
return true;
304318
}
305319

320+
// Default-constructs the XCCL options by delegating to Backend::Options with
// the XCCL backend name.
ProcessGroupXCCL::Options::Options() : Backend::Options(XCCL_BACKEND_NAME) {}
322+
323+
306324
static std::atomic<size_t> process_group_id = 0;
307325

308326
constexpr const char* MULTI_DEVICE_ERROR_MSG =
@@ -332,19 +350,28 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
332350
ProcessGroupXCCL::ProcessGroupXCCL(
333351
const c10::intrusive_ptr<Store>& store,
334352
int rank,
335-
int size)
353+
int size,
354+
c10::intrusive_ptr<Options> options)
336355
: Backend(rank, size),
337356
store_(store),
357+
options_(std::move(options)),
338358
xcclCommCounter_(0),
339359
local_id_(process_group_id++) {
340360
logPrefix_ = createLogPrefix();
341361
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
362+
traceBufferSize_ = getCvarInt({"TORCH_FR_BUFFER_SIZE"}, 2000);
363+
364+
this->setGroupUid(options_->group_name);
365+
// In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just do it here.
366+
const auto XcclVersion = getXcclVersion();
367+
FlightRecorderXCCL::get()->record_pg_ranks(
368+
std::make_tuple(pg_uid_, pg_desc_), groupRanks());
369+
FlightRecorderXCCL::get()->record_accelerator_version(XcclVersion);
342370
enableNanCheck_ = getCvarBool(TORCH_XCCL_NAN_CHECK, false);
343371
init();
344372
const std::string OFF = "OFF";
345373
std::string torch_distributed_debug =
346374
getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
347-
const auto XcclVersion = getXcclVersion();
348375
LOG(INFO) << logPrefix() << "ProcessGroupXCCL initialization options: "
349376
<< "size: " << size << ", global rank: " << rank_;
350377

@@ -353,9 +380,63 @@ ProcessGroupXCCL::ProcessGroupXCCL(
353380
<< ", TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
354381
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
355382
<< ", TORCH_XCCL_NAN_CHECK: " << enableNanCheck_;
383+
384+
// Heartbeat monitor thread dumps debug info on write to pipe
385+
heartbeatMonitor_ = std::make_unique<HeartbeatMonitorXCCL>(this);
386+
heartbeatMonitor_->start();
387+
}
388+
389+
// Shuts down the heartbeat monitor: asks it to stop, then joins it so the
// destructor does not return before the monitor has finished.
ProcessGroupXCCL::~ProcessGroupXCCL() {
  auto& monitor = *heartbeatMonitor_;
  monitor.stop();
  monitor.join();
}
357394

358-
ProcessGroupXCCL::~ProcessGroupXCCL() = default;
395+
bool ProcessGroupXCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) {
396+
STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupXCCL__dumpDebuggingInfo);
397+
LOG(ERROR)
398+
<< logPrefix()
399+
<< "ProcessGroupXCCL preparing to dump debug info. Include stack trace: "
400+
<< includeStackTrace;
401+
if (traceBufferSize_ > 0) {
402+
// TODO: dump_xccl_trace
403+
auto xcclTrace = dump_xccl_trace(true, includeStackTrace, false);
404+
DebugInfoWriter& writer = DebugInfoWriter::getWriter(rank_);
405+
LOG(INFO) << logPrefix() << "ProcessGroupXCCL dumping xccl trace to "
406+
<< writer.getWriterTarget();
407+
writer.write(xcclTrace);
408+
LOG(INFO) << logPrefix() << "Flight Recorder trace successfully dumped.";
409+
return true;
410+
}
411+
return false;
412+
}
413+
414+
const std::vector<uint64_t>& ProcessGroupXCCL::groupRanks() const {
415+
if (options_->global_ranks_in_group.empty() && local_id_ == 0) {
416+
static std::vector<uint64_t> globalRanks(size_);
417+
std::iota(globalRanks.begin(), globalRanks.end(), 0);
418+
return globalRanks;
419+
}
420+
return options_->global_ranks_in_group;
421+
}
422+
423+
// Copies the given work's sequence number, op name and input/output element
// counts into the "last enqueued" fields of pgStatus_.
void ProcessGroupXCCL::setEnqueuedPgStatus(
    c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
  auto& status = *pgStatus_;
  auto& enqueued = *work;

  status.lastEnqueuedSeq = static_cast<int64_t>(enqueued.getSequencenumber());
  status.lastEnqueuedWorkName = opTypeToString(enqueued.opType_);
  status.lastEnqueuedNumelIn = enqueued.numelIn_;
  status.lastEnqueuedNumelOut = enqueued.numelOut_;
}
430+
431+
// Copies the given work's sequence number, op name and input/output element
// counts into the "last completed" fields of pgStatus_, then retires the
// work's Flight Recorder entry.
void ProcessGroupXCCL::setCompletedPgStatus(
    c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
  auto& status = *pgStatus_;
  auto& completed = *work;

  status.lastCompletedSeq = static_cast<int64_t>(completed.getSequencenumber());
  status.lastCompletedWorkName = opTypeToString(completed.opType_);
  status.lastCompletedNumelIn = completed.numelIn_;
  status.lastCompletedNumelOut = completed.numelOut_;

  // Duration is deliberately not computed here, to keep things simple.
  FlightRecorderXCCL::get()->retire_id(
      completed.trace_id_, /*compute_duration*/ false);
}
359440

360441
// No-op: the body is empty, so calling this has no effect.
void ProcessGroupXCCL::setSequenceNumberForGroup() {}
361442

@@ -384,6 +465,21 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
384465
profilingTitle,
385466
profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
386467
: std::nullopt);
468+
469+
r->trace_id_ = FlightRecorderXCCL::get()->record(
470+
local_id_,
471+
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
472+
seqCollective_,
473+
seqP2P_,
474+
op_id_,
475+
profilingTitle ? profilingTitle : "",
476+
inputs,
477+
outputs,
478+
nullptr,
479+
r->xcclEndEvent_.get(),
480+
options_->timeout,
481+
pgStatus_,
482+
isP2P);
387483
return r;
388484
}
389485

@@ -538,6 +634,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
538634
groupEnd();
539635

540636
work->xcclEndEvent_->record(stream);
637+
setEnqueuedPgStatus(work);
541638

542639
coalescing_state_ = 0;
543640
coalescedComm_ = nullptr;
@@ -572,6 +669,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
572669
if ((coalescing_state_ & CoalColl) == 0) {
573670
seqCollective_++;
574671
}
672+
op_id_++;
575673
coalescing_state_ |= CoalColl;
576674
if (coalescedDevice_.index() < 0) {
577675
coalescedDevice_ = device;
@@ -614,6 +712,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
614712
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
615713
work =
616714
initWork(device, rank_, opType, false, profilingTitle, inputs, outputs);
715+
if (coalescing_state_) {
716+
FlightRecorderXCCL::get()->record(
717+
local_id_,
718+
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
719+
seqCollective_,
720+
seqP2P_,
721+
op_id_,
722+
profilingTitle ? profilingTitle : "",
723+
inputs,
724+
outputs,
725+
nullptr,
726+
nullptr,
727+
options_->timeout,
728+
pgStatus_,
729+
false);
730+
}
617731

618732
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
619733

@@ -653,8 +767,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
653767
work->future_ = c10::make_intrusive<at::ivalue::Future>(
654768
c10::ListType::create(c10::TensorType::get()), devices);
655769
work->future_->markCompleted(at::IValue(*work->outputs_));
770+
work->future_->addCallback(
771+
[this, work](at::ivalue::Future&) {
772+
this->setCompletedPgStatus(work);
773+
});
656774
work->blockingWait_ = blockingWait_;
657775

776+
work->numelIn_ = 0;
777+
work->numelOut_ = 0;
778+
for (const auto& input : inputs) {
779+
work->numelIn_ += input.numel();
780+
}
781+
for (const auto& output : outputs) {
782+
work->numelOut_ += output.numel();
783+
}
784+
setEnqueuedPgStatus(work);
785+
658786
return asyncOp ? work : nullptr;
659787
}
660788

@@ -687,6 +815,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
687815
}
688816
}
689817

818+
op_id_++;
690819
auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf);
691820

692821
if (coalescing_state_ & CoalActive) {
@@ -722,6 +851,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
722851
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
723852
work->outputs_->push_back(tensor);
724853

854+
work->trace_id_ = FlightRecorderXCCL::get()->record(
855+
local_id_,
856+
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
857+
seqCollective_,
858+
seqP2P_,
859+
op_id_,
860+
profilingTitle,
861+
{tensor},
862+
{tensor},
863+
nullptr,
864+
work->xcclEndEvent_.get(),
865+
options_->timeout,
866+
pgStatus_,
867+
true);
868+
725869
c10::OptionalDeviceGuard gpuGuard(device);
726870

727871
c10::xpu::XPUCachingAllocator::recordStream(
@@ -737,8 +881,29 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
737881
work->future_ = c10::make_intrusive<at::ivalue::Future>(
738882
c10::ListType::create(c10::TensorType::get()), devices);
739883
work->future_->markCompleted(at::IValue(*work->outputs_));
884+
work->future_->addCallback(
885+
[this, work](at::ivalue::Future&) {
886+
this->setCompletedPgStatus(work);
887+
});
888+
889+
work->numelIn_ = work->numelOut_ = tensor.numel();
890+
setEnqueuedPgStatus(work);
740891
return work;
741892
} else {
893+
FlightRecorderXCCL::get()->record(
894+
local_id_,
895+
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
896+
seqCollective_,
897+
seqP2P_,
898+
op_id_,
899+
profilingTitle,
900+
{tensor},
901+
{tensor},
902+
nullptr,
903+
nullptr,
904+
options_->timeout,
905+
pgStatus_,
906+
true);
742907
c10::OptionalDeviceGuard gpuGuard(device);
743908

744909
c10::xpu::XPUCachingAllocator::recordStream(
@@ -2135,6 +2300,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
21352300
"xccl:all_to_all");
21362301
}
21372302

2303+
std::string getXcclVersion() {
2304+
auto xccl_version = ccl::get_library_version();
2305+
std::string versionString = std::to_string(xccl_version.major) + "." +
2306+
std::to_string(xccl_version.minor) + "." +
2307+
std::to_string(xccl_version.update);
2308+
return versionString;
2309+
}
2310+
21382311
} // namespace c10d
21392312

21402313
#endif // USE_C10D_XCCL

0 commit comments

Comments
 (0)