Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/BuildOnLinux.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ macro(setup_common_libraries)
if(USE_C10D_XCCL)
target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
endif()
list(APPEND TORCH_XPU_OPS_LIBRARIES torch_xpu_ops)
endmacro()
Expand Down Expand Up @@ -125,6 +126,7 @@ else()
if(USE_C10D_XCCL)
target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
endif()

install(TARGETS torch_xpu_ops DESTINATION "${TORCH_INSTALL_LIB_DIR}")
Expand Down
21 changes: 21 additions & 0 deletions src/xccl/FlightRecorderXCCL.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifdef USE_C10D_XCCL

#include <torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
#include <ATen/xpu/XPUEvent.h>
#include <xccl/ProcessGroupXCCL.hpp>

namespace c10d {

// Flight Recorder duration helper specialized for XPU events: returns the
// elapsed time between the start and end events recorded around an XCCL op.
// Units follow at::xpu::XPUEvent::elapsed_time (presumably milliseconds,
// mirroring the CUDA event API — TODO confirm).
template <>
float getDurationFromEvent<at::xpu::XPUEvent>(
    at::xpu::XPUEvent& xcclStartEvent,
    at::xpu::XPUEvent& xcclEndEvent) {
  // elapsed_time is only meaningful once the work finished; query() on the
  // end event checks completion without blocking.
  TORCH_CHECK(
      xcclEndEvent.query(),
      "getDuration can only be called after work is succeeded.")
  return xcclStartEvent.elapsed_time(xcclEndEvent);
}

// Explicit instantiation so the FlightRecorder template definitions pulled in
// from FlightRecorderDetail.hpp are emitted in this translation unit.
template struct FlightRecorder<at::xpu::XPUEvent>;
} // namespace c10d
#endif // USE_C10D_XCCL
218 changes: 216 additions & 2 deletions src/xccl/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#ifdef USE_C10D_XCCL

#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
#include <xccl/ProcessGroupXCCL.hpp>

namespace c10d {

using FlightRecorderXCCL = FlightRecorder<at::xpu::XPUEvent>;

namespace {

#if defined(CCL_MAJOR_VERSION) && \
Expand Down Expand Up @@ -302,6 +305,10 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
return true;
}

// Default Options: forwards the XCCL backend name to the base
// Backend::Options so this PG is labeled consistently with other backends.
ProcessGroupXCCL::Options::Options()
    : Backend::Options(XCCL_BACKEND_NAME) {}


static std::atomic<size_t> process_group_id = 0;

constexpr const char* MULTI_DEVICE_ERROR_MSG =
Expand Down Expand Up @@ -331,13 +338,21 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
ProcessGroupXCCL::ProcessGroupXCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
int size)
int size,
c10::intrusive_ptr<Options> options)
: Backend(rank, size),
store_(store),
options_(std::move(options)),
xcclCommCounter_(0),
local_id_(process_group_id++) {
logPrefix_ = createLogPrefix();
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
traceBufferSize_ = getCvarInt({"TORCH_FR_BUFFER_SIZE"}, 2000);

this->setGroupUid(options_->group_name);
// In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just do it here.
FlightRecorderXCCL::get()->record_pg_ranks(
std::make_tuple(pg_uid_, pg_desc_), groupRanks());
init();
const std::string OFF = "OFF";
std::string torch_distributed_debug =
Expand All @@ -350,9 +365,124 @@ ProcessGroupXCCL::ProcessGroupXCCL(
<< "XCCL version: " << XcclVersion
<< ", TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug;

// Heartbeat monitor thread dumps debug info on write to pipe
heartbeatMonitor_ = std::make_unique<HeartbeatMonitor>(this);
heartbeatMonitor_->start();
}

ProcessGroupXCCL::~ProcessGroupXCCL() {
  // Ask the heartbeat monitor thread to terminate, then block until it has
  // exited so it can never touch a destroyed ProcessGroupXCCL.
  heartbeatMonitor_->stop();
  // Wait for all threads to finish before returning
  heartbeatMonitor_->join();
}

bool ProcessGroupXCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) {
STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupXCCL__dumpDebuggingInfo);
LOG(ERROR)
<< logPrefix()
<< "ProcessGroupXCCL preparing to dump debug info. Include stack trace: "
<< includeStackTrace;
if (traceBufferSize_ > 0) {
// TODO: dump_xccl_trace
auto xcclDumpMap = std::unordered_map<
std::string,
std::unordered_map<std::string, std::string>>();
auto xcclTrace = FlightRecorderXCCL::get()->dump(
xcclDumpMap, true, includeStackTrace, false);
DebugInfoWriter& writer = DebugInfoWriter::getWriter(rank_);
LOG(INFO) << logPrefix() << "ProcessGroupXCCL dumping xccl trace to "
<< writer.getWriterTarget();
writer.write(xcclTrace);
LOG(INFO) << logPrefix() << "Flight Recorder trace successfully dumped.";
return true;
}
return false;
}

// Side-thread monitor that reacts to dump-pipe signals by dumping Flight
// Recorder debug info. `pg` is a raw back-pointer; the owning
// ProcessGroupXCCL must outlive this monitor.
ProcessGroupXCCL::HeartbeatMonitor::HeartbeatMonitor(ProcessGroupXCCL* pg) {
  pg_ = pg;
  // Poll interval of runLoop(), configurable via env var (default 1000 ms).
  coordCheckIntervalMilSec_ = getCvarInt(TORCH_XCCL_COORD_CHECK_MILSEC, 1000);
  // Fix: the log previously printed "TORCH_XCCL_COOR_CHECK_MILSEC", which is
  // not the env var actually read (TORCH_XCCL_COORD_CHECK_MILSEC).
  LOG(INFO)
      << pg_->logPrefix() << "HeartbeatMonitor environments: "
      << "TORCH_XCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_;
}

void ProcessGroupXCCL::HeartbeatMonitor::stop() {
terminateHeartbeatMonitorThread_.store(true);
monitorWakeUpCV_.notify_one();
}

// Launch the monitor thread. Must be called at most once per instance; a
// second call trips the joinable() check below.
void ProcessGroupXCCL::HeartbeatMonitor::start() {
  TORCH_CHECK(
      !xcclHeartbeatMonitorThread_.joinable(),
      "HeartbeatMonitor thread already started");
  xcclHeartbeatMonitorThread_ =
      std::thread(&ProcessGroupXCCL::HeartbeatMonitor::runLoop, this);
}

// Join the monitor thread if it was started; a no-op when start() was never
// called (joinable() is false then), so this is safe to call unconditionally.
void ProcessGroupXCCL::HeartbeatMonitor::join() {
  if (xcclHeartbeatMonitorThread_.joinable()) {
    xcclHeartbeatMonitorThread_.join();
    LOG(INFO) << pg_->logPrefix()
              << "ProcessGroupXCCL heart beat monitor thread joined.";
  }
}

// Monitor thread body: for the first process group in this process
// (local_id_ == 0) it polls the dump pipe every coordCheckIntervalMilSec_ ms
// and triggers a Flight Recorder dump when signaled; terminates when stop()
// sets the terminate flag.
void ProcessGroupXCCL::HeartbeatMonitor::runLoop() {
  c10::setThreadName("pt_xccl_heartbt");

  std::optional<DumpPipe> dumpPipe = std::nullopt;
  // We only need to dump once per PG, so we use local_id_ == 0 for the first PG
  if (pg_->local_id_ == 0) {
    // DumpPipe is one per-trainer process
    dumpPipe.emplace(pg_->rank_);
    while (true) {
      std::unique_lock<std::mutex> lock(monitorMutex_);
      // Sleep until the poll interval elapses or stop() requests termination;
      // wait_for returns true when the predicate (terminate flag) holds.
      if (monitorWakeUpCV_.wait_for(
              lock, std::chrono::milliseconds(coordCheckIntervalMilSec_), [&] {
                return terminateHeartbeatMonitorThread_.load();
              })) {
        return;
      }
      // Write to pipe files for all ranks to dump debug info
      if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
        LOG(INFO) << pg_->logPrefix()
                  << "Dump signal received through pipe, triggering FR dump.";
        // NOTE(review): the future is discarded, and a std::async future's
        // destructor blocks until the task completes — so the dump runs
        // effectively synchronously at the end of this if-block, while the
        // monitor mutex is still held. Confirm this is intended.
        std::future<bool> fut = std::async(std::launch::async, [this]() {
          return this->pg_->dumpDebuggingInfo();
        });
      }
    }
  }
  // NOTE(review): for PGs with local_id_ != 0 this thread exits immediately
  // after being started — confirm starting it for them is intentional.
}

ProcessGroupXCCL::~ProcessGroupXCCL() = default;
// Returns the global ranks composing this process group. When the caller did
// not provide global_ranks_in_group (e.g. the default/world group), falls
// back to the identity mapping [0, size_).
const std::vector<uint64_t>& ProcessGroupXCCL::groupRanks() const {
  if (options_->global_ranks_in_group.empty()) {
    // NOTE(review): this static vector is sized once, by the first PG that
    // reaches this line; a later PG with an empty global_ranks_in_group but a
    // different size_ would receive a wrongly-sized list. Confirm only the
    // default PG can take this branch.
    static std::vector<uint64_t> globalRanks(size_);
    std::iota(globalRanks.begin(), globalRanks.end(), 0);
    return globalRanks;
  }
  return options_->global_ranks_in_group;
}

// Record `work` as the most recently *started* work item on the PG status
// snapshot consumed by Flight Recorder dumps.
void ProcessGroupXCCL::setStartedPgStatus(
    c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
  auto& status = *pgStatus_;
  status.lastStartedSeq = static_cast<int64_t>(work->getSequencenumber());
  status.lastStartedWorkName = opTypeToString(work->opType_);
  status.lastStartedNumelIn = work->numelIn_;
  status.lastStartedNumelOut = work->numelOut_;
}

// Record `work` as the most recently *completed* work item on the PG status
// snapshot, and retire its Flight Recorder entry.
void ProcessGroupXCCL::setCompletedPgStatus(
    c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
  auto& status = *pgStatus_;
  status.lastCompletedSeq = static_cast<int64_t>(work->getSequencenumber());
  status.lastCompletedWorkName = opTypeToString(work->opType_);
  status.lastCompletedNumelIn = work->numelIn_;
  status.lastCompletedNumelOut = work->numelOut_;
  // To avoid complexity, we're not computing duration.
  FlightRecorderXCCL::get()->retire_id(
      work->trace_id_, /*compute_duration=*/false);
}

// Intentional no-op: this backend keeps its own per-op counters
// (seqCollective_/seqP2P_), so there is no group-wide sequence number to set.
void ProcessGroupXCCL::setSequenceNumberForGroup() {}

Expand All @@ -377,6 +507,21 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
profilingTitle,
profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
: std::nullopt);

r->trace_id_ = FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle ? profilingTitle : "",
inputs,
outputs,
nullptr,
r->xcclEndEvent_.get(),
options_->timeout,
pgStatus_,
isP2P);
return r;
}

Expand Down Expand Up @@ -531,6 +676,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
groupEnd();

work->xcclEndEvent_->record(stream);
setStartedPgStatus(work);

coalescing_state_ = 0;
coalescedComm_ = nullptr;
Expand Down Expand Up @@ -563,6 +709,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
if ((coalescing_state_ & CoalColl) == 0) {
seqCollective_++;
}
op_id_++;
coalescing_state_ |= CoalColl;
if (coalescedDevice_.index() < 0) {
coalescedDevice_ = device;
Expand Down Expand Up @@ -605,6 +752,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
work =
initWork(device, rank_, opType, false, profilingTitle, inputs, outputs);
if (coalescing_state_) {
FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle ? profilingTitle : "",
inputs,
outputs,
nullptr,
nullptr,
options_->timeout,
pgStatus_,
false);
}

work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);

Expand Down Expand Up @@ -638,8 +801,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
work->future_->addCallback(
[this, work](at::ivalue::Future&) {
this->setCompletedPgStatus(work);
});
work->blockingWait_ = blockingWait_;

work->numelIn_ = 0;
work->numelOut_ = 0;
for (const auto& input : inputs) {
work->numelIn_ += input.numel();
}
for (const auto& output : outputs) {
work->numelOut_ += output.numel();
}
setStartedPgStatus(work);

return asyncOp ? work : nullptr;
}

Expand Down Expand Up @@ -672,6 +849,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
}
}

op_id_++;
auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf);

if (coalescing_state_ & CoalActive) {
Expand Down Expand Up @@ -703,6 +881,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
work->outputs_->push_back(tensor);

work->trace_id_ = FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle,
{tensor},
{tensor},
nullptr,
work->xcclEndEvent_.get(),
options_->timeout,
pgStatus_,
true);

c10::OptionalDeviceGuard gpuGuard(device);

c10::xpu::XPUCachingAllocator::recordStream(
Expand All @@ -718,8 +911,29 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
work->future_->addCallback(
[this, work](at::ivalue::Future&) {
this->setCompletedPgStatus(work);
});

work->numelIn_ = work->numelOut_ = tensor.numel();
setStartedPgStatus(work);
return work;
} else {
FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle,
{tensor},
{tensor},
nullptr,
nullptr,
options_->timeout,
pgStatus_,
true);
c10::OptionalDeviceGuard gpuGuard(device);

c10::xpu::XPUCachingAllocator::recordStream(
Expand Down
Loading
Loading