Skip to content

UT for FlightRecorderXCCL #1917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/BuildOnLinux.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ macro(setup_common_libraries)
if(USE_C10D_XCCL)
target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
endif()
list(APPEND TORCH_XPU_OPS_LIBRARIES torch_xpu_ops)
endmacro()
Expand Down Expand Up @@ -125,6 +126,7 @@ else()
if(USE_C10D_XCCL)
target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
endif()

install(TARGETS torch_xpu_ops DESTINATION "${TORCH_INSTALL_LIB_DIR}")
Expand Down
21 changes: 21 additions & 0 deletions src/xccl/FlightRecorderXCCL.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifdef USE_C10D_XCCL

#include <torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
#include <ATen/xpu/XPUEvent.h>
#include <xccl/ProcessGroupXCCL.hpp>

namespace c10d {

/// FlightRecorder timing hook specialized for XPU events.
///
/// @param xcclStartEvent event on which `elapsed_time` is invoked.
/// @param xcclEndEvent   event passed to `elapsed_time`; its `query()` must
///                       return true when this function is called.
/// @return the value of `xcclStartEvent.elapsed_time(xcclEndEvent)`.
/// @throws whatever TORCH_CHECK raises when `xcclEndEvent.query()` is false.
template <>
float getDurationFromEvent<at::xpu::XPUEvent>(
    at::xpu::XPUEvent& xcclStartEvent,
    at::xpu::XPUEvent& xcclEndEvent) {
  // Terminate the check like an ordinary statement instead of relying on the
  // macro happening to expand to a braced block.
  TORCH_CHECK(
      xcclEndEvent.query(),
      "getDuration can only be called after work is succeeded.");
  return xcclStartEvent.elapsed_time(xcclEndEvent);
}

template struct FlightRecorder<at::xpu::XPUEvent>;
} // namespace c10d
#endif // USE_C10D_XCCL
179 changes: 176 additions & 3 deletions src/xccl/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#ifdef USE_C10D_XCCL

#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
#include <xccl/NanCheck_XPU.hpp>
#include <xccl/ProcessGroupXCCL.hpp>

namespace c10d {

using FlightRecorderXCCL = FlightRecorder<at::xpu::XPUEvent>;

namespace {

#if defined(CCL_MAJOR_VERSION) && \
Expand Down Expand Up @@ -200,6 +203,17 @@ void syncStream(

} // namespace

/// Returns the string produced by FlightRecorderXCCL::dump().
///
/// The three flags are forwarded unchanged. The first argument handed to
/// dump() is an empty map of string -> (string -> string); no additional
/// entries are supplied from here.
std::string dump_xccl_trace(
    bool includeCollectives,
    bool includeStackTraces,
    bool onlyActive) {
  using StringMap = std::unordered_map<std::string, std::string>;
  std::unordered_map<std::string, StringMap> emptyDumpMap;

  auto&& recorder = FlightRecorderXCCL::get();
  return recorder->dump(
      emptyDumpMap, includeCollectives, includeStackTraces, onlyActive);
}

constexpr int64_t kSynchronizeBusyWaitMillis = 10;
thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0;

Expand Down Expand Up @@ -303,6 +317,10 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
return true;
}

// Default constructor: initializes the Backend::Options base with
// XCCL_BACKEND_NAME and does nothing else.
ProcessGroupXCCL::Options::Options()
: Backend::Options(XCCL_BACKEND_NAME) {}


static std::atomic<size_t> process_group_id = 0;

constexpr const char* MULTI_DEVICE_ERROR_MSG =
Expand Down Expand Up @@ -332,19 +350,28 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
ProcessGroupXCCL::ProcessGroupXCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
int size)
int size,
c10::intrusive_ptr<Options> options)
: Backend(rank, size),
store_(store),
options_(std::move(options)),
xcclCommCounter_(0),
local_id_(process_group_id++) {
logPrefix_ = createLogPrefix();
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
traceBufferSize_ = getCvarInt({"TORCH_FR_BUFFER_SIZE"}, 2000);

this->setGroupUid(options_->group_name);
// In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just do it here.
const auto XcclVersion = getXcclVersion();
FlightRecorderXCCL::get()->record_pg_ranks(
std::make_tuple(pg_uid_, pg_desc_), groupRanks());
FlightRecorderXCCL::get()->record_accelerator_version(XcclVersion);
enableNanCheck_ = getCvarBool(TORCH_XCCL_NAN_CHECK, false);
init();
const std::string OFF = "OFF";
std::string torch_distributed_debug =
getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
const auto XcclVersion = getXcclVersion();
LOG(INFO) << logPrefix() << "ProcessGroupXCCL initialization options: "
<< "size: " << size << ", global rank: " << rank_;

Expand All @@ -353,9 +380,63 @@ ProcessGroupXCCL::ProcessGroupXCCL(
<< ", TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
<< ", TORCH_XCCL_NAN_CHECK: " << enableNanCheck_;

// Heartbeat monitor thread dumps debug info on write to pipe
heartbeatMonitor_ = std::make_unique<HeartbeatMonitorXCCL>(this);
heartbeatMonitor_->start();
}

// Destructor: shuts down the heartbeat monitor by calling stop() and then
// join() on it, in that order.
ProcessGroupXCCL::~ProcessGroupXCCL() {
heartbeatMonitor_->stop();
// Wait for all threads to finish before returning
heartbeatMonitor_->join();
}

ProcessGroupXCCL::~ProcessGroupXCCL() = default;
bool ProcessGroupXCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) {
STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupXCCL__dumpDebuggingInfo);
LOG(ERROR)
<< logPrefix()
<< "ProcessGroupXCCL preparing to dump debug info. Include stack trace: "
<< includeStackTrace;
if (traceBufferSize_ > 0) {
// TODO: dump_xccl_trace
auto xcclTrace = dump_xccl_trace(true, includeStackTrace, false);
DebugInfoWriter& writer = DebugInfoWriter::getWriter(rank_);
LOG(INFO) << logPrefix() << "ProcessGroupXCCL dumping xccl trace to "
<< writer.getWriterTarget();
writer.write(xcclTrace);
LOG(INFO) << logPrefix() << "Flight Recorder trace successfully dumped.";
return true;
}
return false;
}

const std::vector<uint64_t>& ProcessGroupXCCL::groupRanks() const {
if (options_->global_ranks_in_group.empty() && local_id_ == 0) {
static std::vector<uint64_t> globalRanks(size_);
std::iota(globalRanks.begin(), globalRanks.end(), 0);
return globalRanks;
}
return options_->global_ranks_in_group;
}

/// Copies the identifying fields of `work` into the "last enqueued" slots of
/// `pgStatus_`: sequence number, op-type name, and input/output element
/// counts.
void ProcessGroupXCCL::setEnqueuedPgStatus(
    c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
  auto& status = *pgStatus_;
  const auto seq = static_cast<int64_t>(work->getSequencenumber());

  status.lastEnqueuedSeq = seq;
  status.lastEnqueuedWorkName = opTypeToString(work->opType_);
  status.lastEnqueuedNumelIn = work->numelIn_;
  status.lastEnqueuedNumelOut = work->numelOut_;
}

/// Copies the identifying fields of `work` into the "last completed" slots of
/// `pgStatus_` and then retires the work's flight-recorder entry.
void ProcessGroupXCCL::setCompletedPgStatus(
    c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
  auto& status = *pgStatus_;
  const auto seq = static_cast<int64_t>(work->getSequencenumber());

  status.lastCompletedSeq = seq;
  status.lastCompletedWorkName = opTypeToString(work->opType_);
  status.lastCompletedNumelIn = work->numelIn_;
  status.lastCompletedNumelOut = work->numelOut_;

  // Duration computation is deliberately skipped to keep this path simple.
  auto&& recorder = FlightRecorderXCCL::get();
  recorder->retire_id(work->trace_id_, /*compute_duration=*/false);
}

// Has an empty body: calling it performs no work in this backend.
void ProcessGroupXCCL::setSequenceNumberForGroup() {}

Expand Down Expand Up @@ -384,6 +465,21 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
profilingTitle,
profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
: std::nullopt);

r->trace_id_ = FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle ? profilingTitle : "",
inputs,
outputs,
nullptr,
r->xcclEndEvent_.get(),
options_->timeout,
pgStatus_,
isP2P);
return r;
}

Expand Down Expand Up @@ -538,6 +634,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
groupEnd();

work->xcclEndEvent_->record(stream);
setEnqueuedPgStatus(work);

coalescing_state_ = 0;
coalescedComm_ = nullptr;
Expand Down Expand Up @@ -572,6 +669,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
if ((coalescing_state_ & CoalColl) == 0) {
seqCollective_++;
}
op_id_++;
coalescing_state_ |= CoalColl;
if (coalescedDevice_.index() < 0) {
coalescedDevice_ = device;
Expand Down Expand Up @@ -614,6 +712,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
work =
initWork(device, rank_, opType, false, profilingTitle, inputs, outputs);
if (coalescing_state_) {
FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle ? profilingTitle : "",
inputs,
outputs,
nullptr,
nullptr,
options_->timeout,
pgStatus_,
false);
}

work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);

Expand Down Expand Up @@ -653,8 +767,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
work->future_->addCallback(
[this, work](at::ivalue::Future&) {
this->setCompletedPgStatus(work);
});
work->blockingWait_ = blockingWait_;

work->numelIn_ = 0;
work->numelOut_ = 0;
for (const auto& input : inputs) {
work->numelIn_ += input.numel();
}
for (const auto& output : outputs) {
work->numelOut_ += output.numel();
}
setEnqueuedPgStatus(work);

return asyncOp ? work : nullptr;
}

Expand Down Expand Up @@ -687,6 +815,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
}
}

op_id_++;
auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf);

if (coalescing_state_ & CoalActive) {
Expand Down Expand Up @@ -722,6 +851,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
work->outputs_->push_back(tensor);

work->trace_id_ = FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle,
{tensor},
{tensor},
nullptr,
work->xcclEndEvent_.get(),
options_->timeout,
pgStatus_,
true);

c10::OptionalDeviceGuard gpuGuard(device);

c10::xpu::XPUCachingAllocator::recordStream(
Expand All @@ -737,8 +881,29 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
work->future_->addCallback(
[this, work](at::ivalue::Future&) {
this->setCompletedPgStatus(work);
});

work->numelIn_ = work->numelOut_ = tensor.numel();
setEnqueuedPgStatus(work);
return work;
} else {
FlightRecorderXCCL::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
seqCollective_,
seqP2P_,
op_id_,
profilingTitle,
{tensor},
{tensor},
nullptr,
nullptr,
options_->timeout,
pgStatus_,
true);
c10::OptionalDeviceGuard gpuGuard(device);

c10::xpu::XPUCachingAllocator::recordStream(
Expand Down Expand Up @@ -2135,6 +2300,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
"xccl:all_to_all");
}

/// Formats the version reported by ccl::get_library_version() as
/// "<major>.<minor>.<update>".
std::string getXcclVersion() {
  const auto libVersion = ccl::get_library_version();

  std::string text = std::to_string(libVersion.major);
  text += ".";
  text += std::to_string(libVersion.minor);
  text += ".";
  text += std::to_string(libVersion.update);
  return text;
}

} // namespace c10d

#endif // USE_C10D_XCCL
Loading