diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake
index aee7118f01..87b0fe3454 100644
--- a/src/BuildOnLinux.cmake
+++ b/src/BuildOnLinux.cmake
@@ -16,6 +16,7 @@ macro(setup_common_libraries)
   if(USE_C10D_XCCL)
     target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
     target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
+    target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
   endif()
   list(APPEND TORCH_XPU_OPS_LIBRARIES torch_xpu_ops)
 endmacro()
@@ -125,6 +126,7 @@ else()
     if(USE_C10D_XCCL)
       target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
       target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
+      target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
     endif()
     install(TARGETS torch_xpu_ops DESTINATION "${TORCH_INSTALL_LIB_DIR}")
diff --git a/src/xccl/FlightRecorderXCCL.cpp b/src/xccl/FlightRecorderXCCL.cpp
new file mode 100644
index 0000000000..29fccd9907
--- /dev/null
+++ b/src/xccl/FlightRecorderXCCL.cpp
@@ -0,0 +1,21 @@
+#ifdef USE_C10D_XCCL
+
+#include <ATen/xpu/XPUEvent.h>
+#include <torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
+#include <xccl/ProcessGroupXCCL.hpp>
+
+namespace c10d {
+
+template <>
+float getDurationFromEvent(
+    at::xpu::XPUEvent& xcclStartEvent,
+    at::xpu::XPUEvent& xcclEndEvent) {
+  TORCH_CHECK(
+      xcclEndEvent.query(),
+      "getDuration can only be called after work has succeeded.")
+  return xcclStartEvent.elapsed_time(xcclEndEvent);
+}
+
+template struct FlightRecorder<at::xpu::XPUEvent>;
+} // namespace c10d
+#endif // USE_C10D_XCCL
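The getDurationFromEvent specialization above is the only XPU-specific piece the generic FlightRecorder template needs: durations come from the elapsed time between a start and an end event, and may only be read once the end event has completed. A rough Python-side sketch of the same event-pair timing (torch.xpu mirrors the C++ at::xpu::XPUEvent API; assumes a working XPU build):

    import torch

    start = torch.xpu.Event(enable_timing=True)
    end = torch.xpu.Event(enable_timing=True)
    start.record()
    torch.ones(1 << 20, device="xpu").sum()  # any queued work
    end.record()
    torch.xpu.synchronize()  # analogous to the end-event query() check above
    duration_ms = start.elapsed_time(end)  # the quantity FR stores as duration_ms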
diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
index 9ace6ccb90..240a54cea4 100644
--- a/src/xccl/ProcessGroupXCCL.cpp
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -1,11 +1,14 @@
 #ifdef USE_C10D_XCCL
 
 #include <comm/XPUGuard.h>
+#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
 #include <xccl/NanCheck_XPU.hpp>
 #include <xccl/ProcessGroupXCCL.hpp>
 
 namespace c10d {
 
+using FlightRecorderXCCL = FlightRecorder<at::xpu::XPUEvent>;
+
 namespace {
 
 #if defined(CCL_MAJOR_VERSION) && \
@@ -200,6 +203,17 @@ void syncStream(
 
 } // namespace
 
+std::string dump_xccl_trace(
+    bool includeCollectives,
+    bool includeStackTraces,
+    bool onlyActive) {
+  auto xcclDumpMap = std::unordered_map<
+      std::string,
+      std::unordered_map<std::string, std::string>>();
+  return FlightRecorderXCCL::get()->dump(
+      xcclDumpMap, includeCollectives, includeStackTraces, onlyActive);
+}
+
 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
 
 thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0;
@@ -303,6 +317,10 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
   return true;
 }
 
+ProcessGroupXCCL::Options::Options()
+    : Backend::Options(XCCL_BACKEND_NAME) {}
+
 static std::atomic<size_t> process_group_id = 0;
 
 constexpr const char* MULTI_DEVICE_ERROR_MSG =
@@ -332,19 +350,28 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
 ProcessGroupXCCL::ProcessGroupXCCL(
     const c10::intrusive_ptr<Store>& store,
     int rank,
-    int size)
+    int size,
+    c10::intrusive_ptr<Options> options)
     : Backend(rank, size),
       store_(store),
+      options_(std::move(options)),
       xcclCommCounter_(0),
       local_id_(process_group_id++) {
   logPrefix_ = createLogPrefix();
   blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
+  traceBufferSize_ = getCvarInt({"TORCH_FR_BUFFER_SIZE"}, 2000);
+
+  this->setGroupUid(options_->group_name);
+  // In PGNCCL, the pg ranks are recorded on comm setup in each op;
+  // here we record them once at construction instead.
+  const auto XcclVersion = getXcclVersion();
+  FlightRecorderXCCL::get()->record_pg_ranks(
+      std::make_tuple(pg_uid_, pg_desc_), groupRanks());
+  FlightRecorderXCCL::get()->record_accelerator_version(XcclVersion);
   enableNanCheck_ = getCvarBool(TORCH_XCCL_NAN_CHECK, false);
   init();
   const std::string OFF = "OFF";
   std::string torch_distributed_debug =
       getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
-  const auto XcclVersion = getXcclVersion();
   LOG(INFO) << logPrefix() << "ProcessGroupXCCL initialization options: "
             << "size: " << size << ", global rank: " << rank_;
@@ -353,9 +380,63 @@ ProcessGroupXCCL::ProcessGroupXCCL(
             << ", TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
             << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
             << ", TORCH_XCCL_NAN_CHECK: " << enableNanCheck_;
+
+  // Heartbeat monitor thread dumps debug info on write to pipe
+  heartbeatMonitor_ = std::make_unique<HeartbeatMonitorXCCL>(this);
+  heartbeatMonitor_->start();
+}
+
+ProcessGroupXCCL::~ProcessGroupXCCL() {
+  heartbeatMonitor_->stop();
+  // Wait for all threads to finish before returning
+  heartbeatMonitor_->join();
 }
 
-ProcessGroupXCCL::~ProcessGroupXCCL() = default;
+bool ProcessGroupXCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) {
+  STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupXCCL__dumpDebuggingInfo);
+  LOG(ERROR)
+      << logPrefix()
+      << "ProcessGroupXCCL preparing to dump debug info. Include stack trace: "
+      << includeStackTrace;
+  if (traceBufferSize_ > 0) {
+    // TODO: dump_xccl_trace
+    auto xcclTrace = dump_xccl_trace(true, includeStackTrace, false);
+    DebugInfoWriter& writer = DebugInfoWriter::getWriter(rank_);
+    LOG(INFO) << logPrefix() << "ProcessGroupXCCL dumping xccl trace to "
+              << writer.getWriterTarget();
+    writer.write(xcclTrace);
+    LOG(INFO) << logPrefix() << "Flight Recorder trace successfully dumped.";
+    return true;
+  }
+  return false;
+}
+
+const std::vector<uint64_t>& ProcessGroupXCCL::groupRanks() const {
+  if (options_->global_ranks_in_group.empty() && local_id_ == 0) {
+    static std::vector<uint64_t> globalRanks(size_);
+    std::iota(globalRanks.begin(), globalRanks.end(), 0);
+    return globalRanks;
+  }
+  return options_->global_ranks_in_group;
+}
+
+void ProcessGroupXCCL::setEnqueuedPgStatus(
+    c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
+  pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->getSequencenumber());
+  pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_);
+  pgStatus_->lastEnqueuedNumelIn = work->numelIn_;
+  pgStatus_->lastEnqueuedNumelOut = work->numelOut_;
+}
+
+void ProcessGroupXCCL::setCompletedPgStatus(
+    c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
+  pgStatus_->lastCompletedSeq = static_cast<int64_t>(work->getSequencenumber());
+  pgStatus_->lastCompletedWorkName = opTypeToString(work->opType_);
+  pgStatus_->lastCompletedNumelIn = work->numelIn_;
+  pgStatus_->lastCompletedNumelOut = work->numelOut_;
+  // To avoid complexity, we're not computing duration.
+  FlightRecorderXCCL::get()->retire_id(
+      work->trace_id_, /*compute_duration=*/false);
+}
 
 void ProcessGroupXCCL::setSequenceNumberForGroup() {}
@@ -384,6 +465,21 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
       profilingTitle,
       profilingTitle != nullptr
           ? std::optional<std::vector<at::Tensor>>(inputs)
           : std::nullopt);
+
+  r->trace_id_ = FlightRecorderXCCL::get()->record(
+      local_id_,
+      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
+      seqCollective_,
+      seqP2P_,
+      op_id_,
+      profilingTitle ? profilingTitle : "",
+      inputs,
+      outputs,
+      nullptr,
+      r->xcclEndEvent_.get(),
+      options_->timeout,
+      pgStatus_,
+      isP2P);
   return r;
 }
@@ -538,6 +634,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
   groupEnd();
 
   work->xcclEndEvent_->record(stream);
+  setEnqueuedPgStatus(work);
 
   coalescing_state_ = 0;
   coalescedComm_ = nullptr;
@@ -572,6 +669,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
     if ((coalescing_state_ & CoalColl) == 0) {
       seqCollective_++;
     }
+    op_id_++;
     coalescing_state_ |= CoalColl;
     if (coalescedDevice_.index() < 0) {
       coalescedDevice_ = device;
@@ -614,6 +712,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
   work =
       initWork(device, rank_, opType, false, profilingTitle, inputs, outputs);
+  if (coalescing_state_) {
+    FlightRecorderXCCL::get()->record(
+        local_id_,
+        std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
+        seqCollective_,
+        seqP2P_,
+        op_id_,
+        profilingTitle ? profilingTitle : "",
+        inputs,
+        outputs,
+        nullptr,
+        nullptr,
+        options_->timeout,
+        pgStatus_,
+        false);
+  }
 
   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
@@ -653,8 +767,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
     work->future_ = c10::make_intrusive<at::ivalue::Future>(
         c10::ListType::create(c10::TensorType::get()), devices);
     work->future_->markCompleted(at::IValue(*work->outputs_));
+    work->future_->addCallback([this, work](at::ivalue::Future&) {
+      this->setCompletedPgStatus(work);
+    });
     work->blockingWait_ = blockingWait_;
 
+    work->numelIn_ = 0;
+    work->numelOut_ = 0;
+    for (const auto& input : inputs) {
+      work->numelIn_ += input.numel();
+    }
+    for (const auto& output : outputs) {
+      work->numelOut_ += output.numel();
+    }
+    setEnqueuedPgStatus(work);
+
     return asyncOp ? work : nullptr;
 }
@@ -687,6 +815,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     }
   }
 
+  op_id_++;
   auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf);
 
   if (coalescing_state_ & CoalActive) {
@@ -722,6 +851,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
     work->outputs_->push_back(tensor);
 
+    work->trace_id_ = FlightRecorderXCCL::get()->record(
+        local_id_,
+        std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
+        seqCollective_,
+        seqP2P_,
+        op_id_,
+        profilingTitle,
+        {tensor},
+        {tensor},
+        nullptr,
+        work->xcclEndEvent_.get(),
+        options_->timeout,
+        pgStatus_,
+        true);
+
     c10::OptionalDeviceGuard gpuGuard(device);
 
     c10::xpu::XPUCachingAllocator::recordStream(
@@ -737,8 +881,29 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     work->future_ = c10::make_intrusive<at::ivalue::Future>(
         c10::ListType::create(c10::TensorType::get()), devices);
     work->future_->markCompleted(at::IValue(*work->outputs_));
+    work->future_->addCallback([this, work](at::ivalue::Future&) {
+      this->setCompletedPgStatus(work);
+    });
+
+    work->numelIn_ = work->numelOut_ = tensor.numel();
+    setEnqueuedPgStatus(work);
 
     return work;
   } else {
+    FlightRecorderXCCL::get()->record(
+        local_id_,
+        std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
+        seqCollective_,
+        seqP2P_,
+        op_id_,
+        profilingTitle,
+        {tensor},
+        {tensor},
+        nullptr,
+        nullptr,
+        options_->timeout,
+        pgStatus_,
+        true);
     c10::OptionalDeviceGuard gpuGuard(device);
 
     c10::xpu::XPUCachingAllocator::recordStream(
@@ -2135,6 +2300,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
       "xccl:all_to_all");
 }
 
+std::string getXcclVersion() {
+  auto xccl_version = ccl::get_library_version();
+  std::string versionString = std::to_string(xccl_version.major) + "." +
+      std::to_string(xccl_version.minor) + "." +
+      std::to_string(xccl_version.update);
+  return versionString;
+}
+
 } // namespace c10d
 
 #endif // USE_C10D_XCCL
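dump_xccl_trace() is what the Python binding ultimately serializes: a pickled dict with pg_config, pg_status, the XCCL version string, and (optionally) per-op entries. A minimal inspection sketch, assuming the _dump_xccl_trace binding that the tests below exercise and an already-initialized "xccl" process group:

    import pickle
    import torch

    trace = pickle.loads(
        torch._C._distributed_c10d._dump_xccl_trace(
            includeCollectives=True, includeStackTraces=True, onlyActive=False
        )
    )
    print(trace["pg_status"]["0"]["last_completed_collective"])
    for entry in trace.get("entries", []):
        print(entry["profiling_name"], entry["state"], entry["collective_seq_id"])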
diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp
index 919f3809ee..42c1997356 100644
--- a/src/xccl/ProcessGroupXCCL.hpp
+++ b/src/xccl/ProcessGroupXCCL.hpp
@@ -19,13 +19,19 @@
 #include <torch/csrc/distributed/c10d/Backend.hpp>
 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
 #include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
 #include <torch/csrc/distributed/c10d/logger.hpp>
+#include <xccl/ProcessGroupXCCLMonitor.hpp>
 
 namespace c10d {
 
 static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
     "TORCH_XCCL_BLOCKING_WAIT",
     "XCCL_BLOCKING_WAIT"};
 
+static std::vector<std::string> TORCH_XCCL_COORD_CHECK_MILSEC = {
+    "TORCH_XCCL_COORD_CHECK_MILSEC",
+    "XCCL_COORD_CHECK_MILSEC"};
+
 using xcclComm_t = ccl::communicator;
 
 static std::vector<std::string> TORCH_XCCL_NAN_CHECK = {"TORCH_XCCL_NAN_CHECK"};
@@ -100,6 +106,9 @@ class TORCH_API ProcessGroupXCCL : public Backend {
     std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
     uint64_t seq_;
     bool isP2P_;
+    std::optional<uint64_t> trace_id_;
+    size_t numelIn_ = -1;
+    size_t numelOut_ = -1;
 
    private:
     std::shared_ptr<std::vector<at::Tensor>> outputs_;
@@ -108,7 +117,22 @@ class TORCH_API ProcessGroupXCCL : public Backend {
     friend class ProcessGroupXCCL;
   };
 
-  ProcessGroupXCCL(const c10::intrusive_ptr<Store>& store, int rank, int size);
+  struct Options : public Backend::Options {
+    explicit Options();
+
+    static c10::intrusive_ptr<Options> create() {
+      return c10::make_intrusive<Options>();
+    }
+
+    std::vector<uint64_t> global_ranks_in_group;
+    std::string group_name;
+  };
+
+  ProcessGroupXCCL(
+      const c10::intrusive_ptr<Store>& store,
+      int rank,
+      int size,
+      c10::intrusive_ptr<Options> options = Options::create());
 
   C10_DEPRECATED ProcessGroupXCCL(
       const c10::intrusive_ptr<Store>& store,
@@ -389,6 +413,11 @@ class TORCH_API ProcessGroupXCCL : public Backend {
 
   c10::DeviceIndex guessDeviceId() const;
 
+  const std::vector<uint64_t>& groupRanks() const;
+  void setEnqueuedPgStatus(c10::intrusive_ptr<WorkXCCL> work);
+  void setCompletedPgStatus(c10::intrusive_ptr<WorkXCCL> work);
+  bool dumpDebuggingInfo(bool includeStackTrace = true);
+
  protected:
   std::unordered_map<std::string, std::vector<at::xpu::XPUStream>>
      xcclStreamsMap_;
@@ -407,10 +436,18 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   static thread_local uint64_t xcclActiveGroupCounter_;
   uint64_t seqCollective_{0};
   uint64_t seqP2P_{0};
+  uint64_t op_id_{0};
   size_t local_id_;
   std::string logPrefix_;
+  const c10::intrusive_ptr<Options> options_;
+  std::shared_ptr<ProcessGroupStatus> pgStatus_ =
+      std::make_shared<ProcessGroupStatus>();
+  std::unique_ptr<HeartbeatMonitorXCCL> heartbeatMonitor_;
+  int traceBufferSize_;
   bool enableNanCheck_;
 
+  friend class HeartbeatMonitorXCCL;
+
  private:
   std::mutex kvs_mutex;
@@ -448,18 +485,18 @@ class TORCH_API ProcessGroupXCCL : public Backend {
     return kvs;
   }
 };
+
+// Dumps the comm traces and additional information about the ProcessGroup.
+TORCH_API std::string dump_xccl_trace(
+    bool includeCollectives,
+    bool includeStackTraces,
+    bool onlyActive);
+
+TORCH_API std::string getXcclVersion();
 } // namespace c10d
 
 namespace {
 
-inline std::string getXcclVersion() {
-  auto xccl_version = ccl::get_library_version();
-  std::string versionString = std::to_string(xccl_version.major) + "." +
-      std::to_string(xccl_version.minor) + "." +
-      std::to_string(xccl_version.update);
-  return versionString;
-}
-
 inline std::string reduceOpToString(c10d::ReduceOp op) {
   switch (op) {
     case c10d::ReduceOp::SUM:
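Options carries the metadata the Flight Recorder keys its records on: group_name becomes the PG uid via setGroupUid(), and global_ranks_in_group feeds groupRanks() (the default group falls back to an iota over 0..size-1). From Python, the expectation is that the usual subgroup machinery populates both; a hedged sketch, assuming the c10d wiring passes these through to the XCCL backend:

    import torch.distributed as dist

    # after dist.init_process_group("xccl", ...) on every rank:
    pg = dist.new_group(ranks=[0, 1])  # should populate global_ranks_in_group
                                       # and a generated group_name for this
                                       # subgroup's Options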
diff --git a/src/xccl/ProcessGroupXCCLMonitor.cpp b/src/xccl/ProcessGroupXCCLMonitor.cpp
new file mode 100644
index 0000000000..cefc6d4022
--- /dev/null
+++ b/src/xccl/ProcessGroupXCCLMonitor.cpp
@@ -0,0 +1,66 @@
+#ifdef USE_C10D_XCCL
+
+#include <xccl/ProcessGroupXCCLMonitor.hpp>
+#include <xccl/ProcessGroupXCCL.hpp>
+
+namespace c10d {
+
+HeartbeatMonitorXCCL::HeartbeatMonitorXCCL(ProcessGroupXCCL* pg) {
+  pg_ = pg;
+  coordCheckIntervalMilSec_ = getCvarInt(TORCH_XCCL_COORD_CHECK_MILSEC, 1000);
+  LOG(INFO) << pg_->logPrefix() << "HeartbeatMonitor environments: "
+            << "TORCH_XCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_;
+}
+
+void HeartbeatMonitorXCCL::stop() {
+  terminateHeartbeatMonitorThread_.store(true);
+  monitorWakeUpCV_.notify_one();
+}
+
+void HeartbeatMonitorXCCL::start() {
+  TORCH_CHECK(
+      !xcclHeartbeatMonitorThread_.joinable(),
+      "HeartbeatMonitor thread already started");
+  xcclHeartbeatMonitorThread_ =
+      std::thread(&HeartbeatMonitorXCCL::runLoop, this);
+}
+
+void HeartbeatMonitorXCCL::join() {
+  if (xcclHeartbeatMonitorThread_.joinable()) {
+    xcclHeartbeatMonitorThread_.join();
+    LOG(INFO) << pg_->logPrefix()
+              << "ProcessGroupXCCL heartbeat monitor thread joined.";
+  }
+}
+
+void HeartbeatMonitorXCCL::runLoop() {
+  c10::setThreadName("pt_xccl_heartbt");
+
+  std::optional<DumpPipe> dumpPipe = std::nullopt;
+  // We only need to dump once per trainer process, so only the first local
+  // PG (local_id_ == 0) owns the pipe.
+  if (pg_->local_id_ == 0) {
+    // DumpPipe is one per-trainer process
+    dumpPipe.emplace(pg_->getRank());
+    while (true) {
+      std::unique_lock<std::mutex> lock(monitorMutex_);
+      if (monitorWakeUpCV_.wait_for(
+              lock, std::chrono::milliseconds(coordCheckIntervalMilSec_), [&] {
+                return terminateHeartbeatMonitorThread_.load();
+              })) {
+        return;
+      }
+      // Write to pipe files for all ranks to dump debug info
+      if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
+        LOG(INFO) << pg_->logPrefix()
+                  << "Dump signal received through pipe, triggering FR dump.";
+        std::future<bool> fut = std::async(std::launch::async, [this]() {
+          return this->pg_->dumpDebuggingInfo();
+        });
+      }
+    }
+  }
+}
+
+} // namespace c10d
+
+#endif // USE_C10D_XCCL
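The monitor thread never dumps on its own; it only polls the pipe, so an external agent decides when to snapshot. A sketch of triggering a dump from outside the trainer, assuming the same TORCH_FR_DEBUG_INFO_PIPE_FILE / TORCH_FR_DUMP_TEMP_FILE base paths the tests below use:

    import os
    import pickle

    base = os.environ["TORCH_FR_DEBUG_INFO_PIPE_FILE"]  # e.g. /tmp/trace_
    with open(f"{base}0.pipe", "w") as f:  # pipe owned by rank 0's process
        f.write("1\n")  # any write makes shouldDump() return true
    # shortly afterwards the dump appears at the writer target, e.g.:
    with open(os.environ["TORCH_FR_DUMP_TEMP_FILE"] + "0", "rb") as f:
        trace = pickle.load(f)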
diff --git a/src/xccl/ProcessGroupXCCLMonitor.hpp b/src/xccl/ProcessGroupXCCLMonitor.hpp
new file mode 100644
index 0000000000..8924c4e43e
--- /dev/null
+++ b/src/xccl/ProcessGroupXCCLMonitor.hpp
@@ -0,0 +1,93 @@
+#pragma once
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include <atomic>
+#include <condition_variable>
+#include <cstring>
+#include <mutex>
+#include <thread>
+
+#ifdef USE_C10D_XCCL
+namespace c10d {
+
+// This definition will later be moved to a common header shared by the
+// NCCL/Gloo/XCCL ProcessGroups.
+#if defined(__linux__)
+struct DumpPipe {
+  DumpPipe(int rank) {
+    std::string fileStem =
+        getCvarString({"TORCH_FR_DEBUG_INFO_PIPE_FILE"}, "");
+    if (fileStem.empty() || getCvarInt({"TORCH_FR_BUFFER_SIZE"}, 0) <= 0) {
+      return;
+    }
+    TORCH_CHECK(!fileStem.empty(), "TORCH_FR_DEBUG_INFO_PIPE_FILE is empty");
+    std::string filename = c10::str(fileStem, rank, ".pipe");
+    TORCH_CHECK(
+        unlink(filename.c_str()) != -1 || errno == ENOENT,
+        "Error removing existing named pipe ",
+        filename,
+        ", Error: ",
+        std::strerror(errno));
+    TORCH_CHECK(
+        mkfifo(filename.c_str(), 0666) != -1,
+        "Error creating named pipe ",
+        filename,
+        ", Error: ",
+        std::strerror(errno));
+    fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
+    LOG(INFO) << "Pipe file " << filename
+              << " has been opened, write to it to trigger ProcessGroup Debug Dump.";
+    TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
+  }
+  bool shouldDump() {
+    if (fd_ == -1) {
+      return false;
+    }
+    // NOLINTNEXTLINE(*array*)
+    char buf[128]{};
+    // Non-blocking read (O_NONBLOCK above). EINTR is ignored because we
+    // will poll again on the next heartbeat tick.
+    ssize_t bytesRead = read(fd_, &buf, 128);
+    return bytesRead > 0;
+  }
+  ~DumpPipe() {
+    if (fd_ != -1) {
+      close(fd_);
+    }
+  }
+
+ private:
+  int fd_ = -1;
+};
+#else
+struct DumpPipe {
+  DumpPipe(int rank) {}
+  bool shouldDump() {
+    return false;
+  }
+};
+#endif
+
+class ProcessGroupXCCL;
+
+class HeartbeatMonitorXCCL {
+ public:
+  HeartbeatMonitorXCCL(ProcessGroupXCCL* pg);
+  virtual ~HeartbeatMonitorXCCL() = default;
+
+  std::string getXCCLTimeoutErrorMsg(const std::string& extraMsg);
+  void start();
+  void join();
+  virtual void runLoop();
+  void stop();
+
+ protected:
+  ProcessGroupXCCL* pg_;
+
+ private:
+  int coordCheckIntervalMilSec_;
+  std::condition_variable monitorWakeUpCV_;
+  std::mutex monitorMutex_;
+  std::thread xcclHeartbeatMonitorThread_;
+  std::atomic<bool> terminateHeartbeatMonitorThread_{false};
+};
+} // namespace c10d
+#endif // USE_C10D_XCCL
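DumpPipe's contract is easy to emulate: a named FIFO opened O_RDONLY | O_NONBLOCK so that neither construction nor polling ever blocks the heartbeat thread. A standalone Python sketch of the same pattern (hypothetical path, not the ProcessGroup wiring):

    import os

    path = "/tmp/demo_rank0.pipe"
    try:
        os.unlink(path)  # remove a stale pipe; tolerate its absence
    except FileNotFoundError:
        pass
    os.mkfifo(path, 0o666)
    fd = os.open(path, os.O_RDONLY | os.O_NONBLOCK)  # never blocks on open

    def should_dump() -> bool:
        try:
            return len(os.read(fd, 128)) > 0  # any bytes mean "dump now"
        except BlockingIOError:  # writer attached but no data yet (EAGAIN)
            return False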
diff --git a/test/xpu/distributed/test_c10d_xccl.py b/test/xpu/distributed/test_c10d_xccl.py
index 916524073c..e143be041e 100644
--- a/test/xpu/distributed/test_c10d_xccl.py
+++ b/test/xpu/distributed/test_c10d_xccl.py
@@ -1,16 +1,21 @@
 # Owner(s): ["oncall: distributed"]
 
+import json
 import math
 import os
+import pickle
 import random
 import signal
 import sys
+import tempfile
+import threading
 import time
-from datetime import timedelta
+from datetime import datetime, timedelta
 from enum import auto, Enum
 from unittest import mock
 
 import torch
+import torch._C._distributed_c10d
 import torch.distributed as c10d
 
 if not c10d.is_available() or not c10d.is_xccl_available():
@@ -627,6 +632,316 @@ def test_all_gather_into_tensor(self):
         )
 
 
+class XCCLTraceTestBase(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        os.environ["TORCH_FR_BUFFER_SIZE"] = "1000"
+        self.tempdir = tempfile.TemporaryDirectory()
+        os.environ["TORCH_FR_DUMP_TEMP_FILE"] = self._trace_basename()
+        os.environ["TORCH_FR_DEBUG_INFO_PIPE_FILE"] = self._trace_basename()
+        self._spawn_processes()
+
+    @classmethod
+    def _run(
+        cls,
+        parent_conn,
+        rank: int,
+        test_name: str,
+        file_name: str,
+        parent_pipe,
+        **kwargs,
+    ) -> None:
+        cls.parent = parent_conn
+        super()._run(rank, test_name, file_name, parent_pipe)
+
+    @property
+    def local_device(self):
+        return torch.device("xpu", self.rank_to_GPU[self.rank][0])
+
+    def _join_processes(self, fn):
+        # We need to patch sys.exit(), since skip_if uses sys.exit() and
+        # the exit code from this process would otherwise not be caught.
+        with mock.patch("sys.exit"):
+            fn()
+        super()._join_processes(fn)
+
+    def _spawn_processes(self) -> None:
+        proc = torch.multiprocessing.get_context("spawn").Process
+        self.children_pipes = []
+        parent_pipes = []
+        for _ in range(self.world_size):
+            parent_conn, child_conn = torch.multiprocessing.Pipe()
+            self.children_pipes.append(child_conn)
+            parent_pipes.append(parent_conn)
+        piter = iter(parent_pipes)
+
+        def wrap(*positional, args, **kwargs):
+            args = (next(piter), *args)
+            return proc(*positional, args=args, **kwargs)
+
+        self._start_processes(wrap)
+
+    def _create_process_group_xccl(
+        self, timeout=timedelta(seconds=600), device_id=None
+    ):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            "xccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+            timeout=timeout,
+            device_id=device_id,
+        )
+        pg = c10d.distributed_c10d._get_default_group()
+        return pg
+
+    def tearDown(self):
+        super().tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    @property
+    def world_size(self):
+        return 2
+
+    @property
+    def rank_to_GPU(self):
+        # return rank to GPU map
+        return init_multigpu_helper(self.world_size, "xccl")
+
+    def _trace_basename(self):
+        # we pass the base to the env, and the dump util will append the rank
+        return os.path.join(self.tempdir.name, "trace_")
+
+    def _trace_name(self, rank):
+        return self._trace_basename() + str(rank)
+
+    def started_or_scheduled(self, timing_enabled=False):
+        return "started" if timing_enabled else "scheduled"
+
+
+class XCCLTraceTest(XCCLTraceTestBase):
+    def _verify_trace(self, t, include_collectives, is_json, timing_enabled=False):
+        ver = t["version"]
+        self.assertEqual(ver, "2.9")
+        xccl_version = t["xccl_version"]
+        torch_xccl_version = torch._C._distributed_c10d.get_xccl_version()
+        self.assertEqual(xccl_version, torch_xccl_version)
+        pg_config = t["pg_config"]
+        self.assertEqual(len(pg_config), 1)
+        default_pg_info = pg_config["0"]
+        self.assertIn("name", default_pg_info)
+        self.assertIn("desc", default_pg_info)
+        self.assertIn("ranks", default_pg_info)
+        pg_status = t["pg_status"]
+        self.assertEqual(len(pg_status), 1)
+        self.assertEqual(str(pg_status["0"]["last_enqueued_collective"]), "2")
+        self.assertEqual(str(pg_status["0"]["last_completed_collective"]), "2")
+        self.assertEqual(
+            str(pg_status["0"]["last_started_collective"]),
+            "2" if timing_enabled else "-1",
+        )
+        global_ranks = pg_config["0"]["ranks"]
+        self.assertEqual(len(json.loads(global_ranks)), self.world_size)
+        if include_collectives:
+            self.assertEqual(len(t["entries"]), 2)
+            t = t["entries"]
+            last = t[-1]
+            self.assertEqual(last["thread_id"], str(threading.current_thread().ident))
+            self.assertEqual(last["thread_name"], "fr_test_thread")
+            self.assertEqual(last["process_group"], ("0", "default_pg"))
+            self.assertEqual(last["state"], "completed")
+            s = last["time_discovered_started_ns"]
+            f = last["time_discovered_completed_ns"]
+            self.assertEqual(last["record_id"], 1)
+            self.assertIsNotNone(f)
+            if timing_enabled:
+                self.assertIsNotNone(s)
+                self.assertTrue(s <= f)
+            # we don't collect stack traces in JSON at the moment
+            if not is_json:
+                self.assertIn("test_c10d_xccl.py", str(last["frames"]))
+            self.assertEqual(last["input_sizes"], ((3, 4),))
+            self.assertEqual(last["input_dtypes"], ["Float"])
+            self.assertEqual(last["output_sizes"], ((3, 4),))
+            self.assertEqual(last["output_dtypes"], ["Float"])
+            self.assertEqual(last["collective_seq_id"], 2)
+            self.assertEqual(last["timeout_ms"], 600000)
+            now = datetime.now()
+            event_created_time = datetime.fromtimestamp(
+                last["time_created_ns"] / 1000000000
+            )
+            before_test = now - timedelta(minutes=1)
+            self.assertTrue(before_test < event_created_time < now)
+            if timing_enabled:
+                # very loose bounds, measured 0.036 ms on devgpu
+                self.assertTrue(0 < last["duration_ms"] < 100)
+            else:
+                self.assertTrue("duration_ms" not in last)
+        else:
+            self.assertTrue("entries" not in t)
+
+    def load_libpthread_or_libc(self):
+        import ctypes.util
+
+        for base in ("pthread", "c"):
+            path = ctypes.util.find_library(base)
+            if path:
+                try:
+                    return ctypes.CDLL(path)
+                except OSError:
+                    continue
+        raise RuntimeError("Could not load pthread or libc")
+
+    # Setting the name via threading.current_thread().name does not work here,
+    # because the C++ side reads the OS-level thread name via pthread_getname_np.
+    def set_thread_name(self, name):
+        import ctypes
+
+        lib = self.load_libpthread_or_libc()
+        pthread_self = lib.pthread_self
+        pthread_self.restype = ctypes.c_void_p
+        pthread_setname_np = lib.pthread_setname_np
+        pthread_setname_np.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+
+        # Get current pthread handle
+        tid = pthread_self()
+
+        # Set name
+        pthread_setname_np(tid, name.encode())
+
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    @parametrize("include_collectives", [True, False])
+    def test_short_pickle(self, include_collectives, timing_enabled=False):
+        if self.rank == self.MAIN_PROCESS_RANK:
+            return
+        pg = self._create_process_group_xccl()
+        if timing_enabled:
+            pg._enable_collectives_timing()
+        device = self.local_device
+        self.set_thread_name("fr_test_thread")
+        a = torch.full((3, 4), float(self.rank), device=device)
+        for _ in range(2):
+            f = pg.allreduce(a)
+        f.wait()
+        torch.xpu.synchronize(device=device)
+        # duration_ms is populated best-effort, since it can only be computed
+        # outside the dump() API; give the watchdog a moment to fill it in.
+        time.sleep(1)
+        t = pickle.loads(
+            torch._C._distributed_c10d._dump_xccl_trace(
+                includeCollectives=include_collectives
+            )
+        )
+        self._verify_trace(
+            t,
+            include_collectives=include_collectives,
+            is_json=True,
+            timing_enabled=timing_enabled,
+        )
+        c10d.destroy_process_group()
+
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    def test_dump_pipe(self):
+        def open_file_with_timeout(file_path, mode, timeout=1.0):
+            start_time = time.time()
+            while time.time() - start_time < timeout:
+                if os.path.exists(file_path):
+                    return open(file_path, mode)
+                time.sleep(0.1)
+            raise FileNotFoundError
+
+        if self.rank == self.MAIN_PROCESS_RANK:
+            for c in self.children_pipes:
+                self.assertEqual(c.recv(), "next")
+
+            dump_file = self._trace_name(rank=0)
+            pipe_file = dump_file + ".pipe"
+            with open_file_with_timeout(pipe_file, "w") as f:
+                f.write("1\n")
+            with open_file_with_timeout(dump_file, "rb", timeout=10.0) as f:
+                self.assertTrue("all_reduce" in str(pickle.load(f)))
+
+            for c in self.children_pipes:
+                c.send("next")
+            return
+
+        pg = self._create_process_group_xccl()
+        device = self.local_device
+        a = torch.full((3, 4), float(self.rank), device=device)
+        for _ in range(2):
+            f = pg.allreduce(a)
+        f.wait()
+        torch.xpu.synchronize(device=device)
+        self.parent.send("next")
+        self.parent.recv()
+
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    def test_long(self):
+        os.environ["TORCH_FR_BUFFER_SIZE"] = "10"
+        if self.rank == self.MAIN_PROCESS_RANK:
+            return
+        pg = self._create_process_group_xccl()
+        device = self.local_device
+        a = torch.full((3, 4), float(self.rank), device=device)
+        for _ in range(2):
+            # test some other primitives to make sure
+            # their strings are valid
+            xs = [torch.ones(3, 4, device=device)]
+            pg.broadcast(xs).wait()
+            pg.allreduce(xs).wait()
+            pg.reduce(xs).wait()
+            ys = [[torch.empty(3, 4, device=device) for _ in range(self.world_size)]]
+            pg.allgather(ys, xs).wait()
+            pg.reduce_scatter(xs, ys).wait()
+            f = pg.allreduce(a)
+        f.wait()
+        torch.xpu.synchronize(device=device)
+        t = pickle.loads(torch._C._distributed_c10d._dump_xccl_trace())
+        t = t["entries"]
+        self.assertEqual(len(t), 10)
+        first = t[0]
+        last = t[-1]
+        self.assertEqual(last["profiling_name"], "xccl:all_reduce")
+        self.assertEqual(last["state"], "completed")
+        self.assertIn("test_c10d_xccl.py", str(last["frames"]))
+        self.assertEqual(last["input_sizes"], ((3, 4),))
+        self.assertEqual(last["input_dtypes"], ["Float"])
+        self.assertEqual(last["output_sizes"], ((3, 4),))
+        self.assertEqual(last["output_dtypes"], ["Float"])
+        self.assertEqual(last["timeout_ms"], 600000)
+        self.assertEqual(last["collective_seq_id"] - first["collective_seq_id"], 9)
+        c10d.destroy_process_group()
+
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    def test_barrier_profiling(self):
+        os.environ["TORCH_FR_BUFFER_SIZE"] = "10"
+        if self.rank == self.MAIN_PROCESS_RANK:
+            return
+        pg = self._create_process_group_xccl()
+        device = self.local_device
+        a = torch.full((3, 4), float(self.rank), device=device)
+        f = pg.barrier()
+        f = pg.allreduce(a)
+        f.wait()
+        torch.xpu.synchronize(device=device)
+        t = pickle.loads(torch._C._distributed_c10d._dump_xccl_trace())
+        t = t["entries"]
+        self.assertEqual(len(t), 2)
+        first = t[0]
+        last = t[-1]
+        self.assertEqual(first["profiling_name"], "xccl:all_reduce_barrier")
+        self.assertEqual(last["profiling_name"], "xccl:all_reduce")
+        c10d.destroy_process_group()
+
+
+instantiate_parametrized_tests(XCCLTraceTest)
 instantiate_parametrized_tests(ProcessGroupXCCLTest)
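Taken together, the tests boil down to a small recipe that also works outside the MultiProcessTestCase harness; a single-rank sketch (hypothetical rendezvous address, assumes an XPU device and an XCCL-enabled build):

    import os
    import pickle

    import torch
    import torch.distributed as dist

    os.environ["TORCH_FR_BUFFER_SIZE"] = "1000"
    dist.init_process_group(
        "xccl", rank=0, world_size=1, init_method="tcp://127.0.0.1:29500"
    )
    t = torch.ones(3, 4, device="xpu")
    dist.all_reduce(t)
    torch.xpu.synchronize()
    trace = pickle.loads(torch._C._distributed_c10d._dump_xccl_trace())
    assert trace["entries"][-1]["profiling_name"] == "xccl:all_reduce"
    dist.destroy_process_group()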