Add FlightRecorder support for ProcessGroupXCCL #1867

Merged Aug 15, 2025 · 19 commits

Changes from 6 commits
2 changes: 2 additions & 0 deletions src/BuildOnLinux.cmake
@@ -16,6 +16,7 @@ macro(setup_common_libraries)
if(USE_C10D_XCCL)
target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
endif()
list(APPEND TORCH_XPU_OPS_LIBRARIES torch_xpu_ops)
endmacro()
@@ -125,6 +126,7 @@ else()
if(USE_C10D_XCCL)
target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
endif()

install(TARGETS torch_xpu_ops DESTINATION "${TORCH_INSTALL_LIB_DIR}")
21 changes: 21 additions & 0 deletions src/xccl/FlightRecorderXCCL.cpp
@@ -0,0 +1,21 @@
#ifdef USE_C10D_XCCL

#include <torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
#include <ATen/xpu/XPUEvent.h>
#include <xccl/ProcessGroupXCCL.hpp>

namespace c10d {

template <>
float getDurationFromEvent<at::xpu::XPUEvent>(
at::xpu::XPUEvent& xcclStartEvent,
at::xpu::XPUEvent& xcclEndEvent) {
TORCH_CHECK(
xcclEndEvent.query(),
"getDuration can only be called after work is succeeded.")
return xcclStartEvent.elapsed_time(xcclEndEvent);
}

template struct FlightRecorder<at::xpu::XPUEvent>;
} // namespace c10d
#endif // USE_C10D_XCCL
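
Note: the specialization above assumes at::xpu::XPUEvent exposes the same timing API as at::cuda::CUDAEvent. A minimal standalone sketch of that pattern, not part of this PR (the enable_timing flag, the stream getter, and the placement of synchronize() are assumptions):

#include <ATen/xpu/XPUEvent.h>
#include <c10/xpu/XPUStream.h>

// Sketch: bracket work on a stream with start/end events so that
// elapsed_time() can report a duration once the end event has completed.
float timeRegionMs() {
  at::xpu::XPUEvent start(/*enable_timing=*/true);
  at::xpu::XPUEvent end(/*enable_timing=*/true);
  auto stream = c10::xpu::getCurrentXPUStream();
  start.record(stream);
  // ... enqueue collective work on `stream` ...
  end.record(stream);
  end.synchronize(); // afterwards end.query() is true, as TORCH_CHECKed above
  return start.elapsed_time(end); // milliseconds
}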
218 changes: 216 additions & 2 deletions src/xccl/ProcessGroupXCCL.cpp
@@ -1,10 +1,13 @@
#ifdef USE_C10D_XCCL

#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
#include <xccl/ProcessGroupXCCL.hpp>

namespace c10d {

using FlightRecorderXCCL = FlightRecorder<at::xpu::XPUEvent>;

namespace {

#if defined(CCL_MAJOR_VERSION) && \
@@ -302,6 +305,10 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
return true;
}

ProcessGroupXCCL::Options::Options()
: Backend::Options(XCCL_BACKEND_NAME) {}

static std::atomic<size_t> process_group_id = 0;

constexpr const char* MULTI_DEVICE_ERROR_MSG =
@@ -331,13 +338,21 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
ProcessGroupXCCL::ProcessGroupXCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
int size)
int size,
c10::intrusive_ptr<Options> options)
: Backend(rank, size),
store_(store),
options_(std::move(options)),
xcclCommCounter_(0),
local_id_(process_group_id++) {
logPrefix_ = createLogPrefix();
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
traceBufferSize_ = getCvarInt({"TORCH_FR_BUFFER_SIZE"}, 2000);

this->setGroupUid(options_->group_name);
// In ProcessGroupNCCL the pg ranks are recorded at comm setup in each op;
// here we record them once, at construction.
FlightRecorderXCCL::get()->record_pg_ranks(
std::make_tuple(pg_uid_, pg_desc_), groupRanks());
init();
const std::string OFF = "OFF";
std::string torch_distributed_debug =
@@ -350,9 +365,124 @@ ProcessGroupXCCL::ProcessGroupXCCL(
<< "XCCL version: " << XcclVersion
<< ", TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug;

// The heartbeat monitor thread dumps debug info when it observes a write to the dump pipe
heartbeatMonitor_ = std::make_unique<HeartbeatMonitor>(this);
heartbeatMonitor_->start();
}
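
Note for reviewers: a hypothetical construction sketch; group_name and global_ranks_in_group are the Backend::Options fields used elsewhere in this diff, everything else (variable names, rank values) is illustrative:

// Not part of this PR: the new Options argument is what lets the Flight
// Recorder label entries per process group.
auto opts = c10::make_intrusive<c10d::ProcessGroupXCCL::Options>();
opts->group_name = "my_pg";                 // becomes pg_uid_ via setGroupUid()
opts->global_ranks_in_group = {0, 1, 2, 3}; // recorded by record_pg_ranks()
auto pg = c10::make_intrusive<c10d::ProcessGroupXCCL>(store, rank, size, opts);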

ProcessGroupXCCL::~ProcessGroupXCCL() {
heartbeatMonitor_->stop();
// Wait for all threads to finish before returning
heartbeatMonitor_->join();
}

bool ProcessGroupXCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) {
STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupXCCL__dumpDebuggingInfo);
LOG(ERROR)
<< logPrefix()
<< "ProcessGroupXCCL preparing to dump debug info. Include stack trace: "
<< includeStackTrace;
if (traceBufferSize_ > 0) {
// TODO: dump_xccl_trace
auto xcclDumpMap = std::unordered_map<
std::string,
std::unordered_map<std::string, std::string>>();
auto xcclTrace = FlightRecorderXCCL::get()->dump(
xcclDumpMap, true, includeStackTrace, false);
DebugInfoWriter& writer = DebugInfoWriter::getWriter(rank_);
LOG(INFO) << logPrefix() << "ProcessGroupXCCL dumping xccl trace to "
<< writer.getWriterTarget();
writer.write(xcclTrace);
LOG(INFO) << logPrefix() << "Flight Recorder trace successfully dumped.";
return true;
}
return false;
}
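
Note: DebugInfoWriter::getWriter(rank) falls back to a default per-rank file writer; a process can install its own sink before the first dump. A sketch assuming the c10d DebugInfoWriter / registerWriter API from torch/csrc/distributed/c10d (the subclass here is hypothetical):

#include <iostream>
#include <memory>

// Hypothetical: send Flight Recorder dumps to stderr instead of the default
// file target. Must be registered before the first getWriter(rank) call.
class StderrWriter : public c10d::DebugInfoWriter {
 public:
  StderrWriter(int rank) : c10d::DebugInfoWriter("stderr", rank) {}
  void write(const std::string& trace) override {
    std::cerr << trace << std::endl;
  }
};
// c10d::DebugInfoWriter::registerWriter(std::make_unique<StderrWriter>(rank));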

ProcessGroupXCCL::HeartbeatMonitor::HeartbeatMonitor(ProcessGroupXCCL* pg) {
pg_ = pg;
coordCheckIntervalMilSec_ = getCvarInt(TORCH_XCCL_COORD_CHECK_MILSEC, 1000);
LOG(INFO)
<< pg_->logPrefix() << "HeartbeatMonitor environments: "
<< "TORCH_XCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_;
}

void ProcessGroupXCCL::HeartbeatMonitor::stop() {
terminateHeartbeatMonitorThread_.store(true);
monitorWakeUpCV_.notify_one();
}

void ProcessGroupXCCL::HeartbeatMonitor::start() {
TORCH_CHECK(
!xcclHeartbeatMonitorThread_.joinable(),
"HeartbeatMonitor thread already started");
xcclHeartbeatMonitorThread_ =
std::thread(&ProcessGroupXCCL::HeartbeatMonitor::runLoop, this);
}

void ProcessGroupXCCL::HeartbeatMonitor::join() {
if (xcclHeartbeatMonitorThread_.joinable()) {
xcclHeartbeatMonitorThread_.join();
LOG(INFO) << pg_->logPrefix()
<< "ProcessGroupXCCL heart beat monitor thread joined.";
}
}

void ProcessGroupXCCL::HeartbeatMonitor::runLoop() {
c10::setThreadName("pt_xccl_heartbt");

std::optional<DumpPipe> dumpPipe = std::nullopt;
// DumpPipe is one per trainer process, so only the first local PG
// (local_id_ == 0) polls it for dump requests.
if (pg_->local_id_ == 0) {
dumpPipe.emplace(pg_->rank_);
while (true) {
std::unique_lock<std::mutex> lock(monitorMutex_);
if (monitorWakeUpCV_.wait_for(
lock, std::chrono::milliseconds(coordCheckIntervalMilSec_), [&] {
return terminateHeartbeatMonitorThread_.load();
})) {
return;
}
// An external agent writes to each rank's pipe file to request a debug-info dump
if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
LOG(INFO) << pg_->logPrefix()
<< "Dump signal received through pipe, triggering FR dump.";
std::future<bool> fut = std::async(std::launch::async, [this]() {
return this->pg_->dumpDebuggingInfo();
});
}
}
}
}
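
Note: the dump is requested externally through DumpPipe. Assuming the per-rank pipe path convention of the NCCL implementation (<TORCH_NCCL_DEBUG_INFO_PIPE_FILE><rank>.pipe; unverified for XCCL), an external watchdog could trigger a dump roughly like this:

#include <fstream>
#include <string>

// Hypothetical trigger: any readable byte in the pipe makes shouldDump()
// return true on the next poll, which then runs dumpDebuggingInfo() async.
void requestDump(const std::string& pipePrefix, int rank) {
  std::ofstream pipe(pipePrefix + std::to_string(rank) + ".pipe");
  pipe << '1'; // the content is irrelevant; the write itself is the signal
}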

const std::vector<uint64_t>& ProcessGroupXCCL::groupRanks() const {
if (options_->global_ranks_in_group.empty()) {
static std::vector<uint64_t> globalRanks(size_);
std::iota(globalRanks.begin(), globalRanks.end(), 0);
return globalRanks;
}
return options_->global_ranks_in_group;
}

void ProcessGroupXCCL::setStartedPgStatus(
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
pgStatus_->lastStartedSeq = static_cast<int64_t>(work->getSequencenumber());
pgStatus_->lastStartedWorkName = opTypeToString(work->opType_);
pgStatus_->lastStartedNumelIn = work->numelIn_;
pgStatus_->lastStartedNumelOut = work->numelOut_;
}

void ProcessGroupXCCL::setCompletedPgStatus(
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
pgStatus_->lastCompletedSeq = static_cast<int64_t>(work->getSequencenumber());
pgStatus_->lastCompletedWorkName = opTypeToString(work->opType_);
pgStatus_->lastCompletedNumelIn = work->numelIn_;
pgStatus_->lastCompletedNumelOut = work->numelOut_;
// To avoid complexity, we're not computing duration.
FlightRecorderXCCL::get()->retire_id(work->trace_id_, /*compute_duration=*/false);
}
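
Note: condensed from this diff, each work object's Flight Recorder entry follows a record/retire lifecycle:

// 1. initWork() / pointToPoint(): trace_id_ = FlightRecorderXCCL::get()->record(...)
//    passing xcclEndEvent_ so the entry can observe completion.
// 2. collective() / pointToPoint(): setStartedPgStatus(work) once the op is issued.
// 3. future callback: setCompletedPgStatus(work), which calls
//    retire_id(trace_id_, /*compute_duration=*/false).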

void ProcessGroupXCCL::setSequenceNumberForGroup() {}

@@ -377,6 +507,21 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
profilingTitle,
profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
: std::nullopt);

r->trace_id_ = FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle ? profilingTitle : "",
inputs,
outputs,
nullptr,
r->xcclEndEvent_.get(),
options_->timeout,
pgStatus_,
isP2P);
return r;
}

@@ -531,6 +676,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
groupEnd();

work->xcclEndEvent_->record(stream);
setStartedPgStatus(work);

coalescing_state_ = 0;
coalescedComm_ = nullptr;
@@ -563,6 +709,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
if ((coalescing_state_ & CoalColl) == 0) {
seqCollective_++;
}
op_id_++;
coalescing_state_ |= CoalColl;
if (coalescedDevice_.index() < 0) {
coalescedDevice_ = device;
@@ -605,6 +752,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
work =
initWork(device, rank_, opType, false, profilingTitle, inputs, outputs);
if (coalescing_state_) {
FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle ? profilingTitle : "",
inputs,
outputs,
nullptr,
nullptr,
options_->timeout,
pgStatus_,
false);
}

work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);

@@ -638,8 +801,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
work->future_->addCallback(
[this, work](at::ivalue::Future&) {
this->setCompletedPgStatus(work);
});
work->blockingWait_ = blockingWait_;

work->numelIn_ = 0;
work->numelOut_ = 0;
for (const auto& input : inputs) {
work->numelIn_ += input.numel();
}
for (const auto& output : outputs) {
work->numelOut_ += output.numel();
}
setStartedPgStatus(work);

return asyncOp ? work : nullptr;
}

@@ -672,6 +849,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
}
}

op_id_++;
auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf);

if (coalescing_state_ & CoalActive) {
@@ -703,6 +881,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
work->outputs_->push_back(tensor);

work->trace_id_ = FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle,
{tensor},
{tensor},
nullptr,
work->xcclEndEvent_.get(),
options_->timeout,
pgStatus_,
true);

c10::OptionalDeviceGuard gpuGuard(device);

c10::xpu::XPUCachingAllocator::recordStream(
@@ -718,8 +911,29 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
work->future_->addCallback(
[this, work](at::ivalue::Future&) {
this->setCompletedPgStatus(work);
});

work->numelIn_ = work->numelOut_ = tensor.numel();
setStartedPgStatus(work);
return work;
} else {
FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle,
{tensor},
{tensor},
nullptr,
nullptr,
options_->timeout,
pgStatus_,
true);
c10::OptionalDeviceGuard gpuGuard(device);

c10::xpu::XPUCachingAllocator::recordStream(