1
1
#ifdef USE_C10D_XCCL
2
2
3
3
#include < torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
4
+ #include < torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
4
5
#include < xccl/NanCheck_XPU.hpp>
5
6
#include < xccl/ProcessGroupXCCL.hpp>
6
7
7
8
namespace c10d {
8
9
10
+ using FlightRecorderXCCL = FlightRecorder<at::xpu::XPUEvent>;
11
+
9
12
namespace {
10
13
11
14
#if defined(CCL_MAJOR_VERSION) && \
@@ -200,6 +203,17 @@ void syncStream(
200
203
201
204
} // namespace
202
205
206
+ std::string dump_xccl_trace (
207
+ bool includeCollectives,
208
+ bool includeStackTraces,
209
+ bool onlyActive) {
210
+ auto xcclDumpMap = std::unordered_map<
211
+ std::string,
212
+ std::unordered_map<std::string, std::string>>();
213
+ return FlightRecorderXCCL::get ()->dump (
214
+ xcclDumpMap, includeCollectives, includeStackTraces, onlyActive);
215
+ }
216
+
203
217
constexpr int64_t kSynchronizeBusyWaitMillis = 10 ;
204
218
thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0 ;
205
219
@@ -303,6 +317,10 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
303
317
return true ;
304
318
}
305
319
320
+ ProcessGroupXCCL::Options::Options ()
321
+ : Backend::Options(XCCL_BACKEND_NAME) {}
322
+
323
+
306
324
static std::atomic<size_t > process_group_id = 0 ;
307
325
308
326
constexpr const char * MULTI_DEVICE_ERROR_MSG =
@@ -332,19 +350,28 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
332
350
ProcessGroupXCCL::ProcessGroupXCCL (
333
351
const c10::intrusive_ptr<Store>& store,
334
352
int rank,
335
- int size)
353
+ int size,
354
+ c10::intrusive_ptr<Options> options)
336
355
: Backend(rank, size),
337
356
store_(store),
357
+ options_(std::move(options)),
338
358
xcclCommCounter_(0 ),
339
359
local_id_(process_group_id++) {
340
360
logPrefix_ = createLogPrefix ();
341
361
blockingWait_ = getCvarBool (TORCH_XCCL_BLOCKING_WAIT, false );
362
+ traceBufferSize_ = getCvarInt ({" TORCH_FR_BUFFER_SIZE" }, 2000 );
363
+
364
+ this ->setGroupUid (options_->group_name );
365
+ // In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just do it here.
366
+ const auto XcclVersion = getXcclVersion ();
367
+ FlightRecorderXCCL::get ()->record_pg_ranks (
368
+ std::make_tuple (pg_uid_, pg_desc_), groupRanks ());
369
+ FlightRecorderXCCL::get ()->record_accelerator_version (XcclVersion);
342
370
enableNanCheck_ = getCvarBool (TORCH_XCCL_NAN_CHECK, false );
343
371
init ();
344
372
const std::string OFF = " OFF" ;
345
373
std::string torch_distributed_debug =
346
374
getCvarString ({" TORCH_DISTRIBUTED_DEBUG" }, OFF.c_str ());
347
- const auto XcclVersion = getXcclVersion ();
348
375
LOG (INFO) << logPrefix () << " ProcessGroupXCCL initialization options: "
349
376
<< " size: " << size << " , global rank: " << rank_;
350
377
@@ -353,9 +380,63 @@ ProcessGroupXCCL::ProcessGroupXCCL(
353
380
<< " , TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
354
381
<< " , TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
355
382
<< " , TORCH_XCCL_NAN_CHECK: " << enableNanCheck_;
383
+
384
+ // Heartbeat monitor thread dumps debug info on write to pipe
385
+ heartbeatMonitor_ = std::make_unique<HeartbeatMonitorXCCL>(this );
386
+ heartbeatMonitor_->start ();
387
+ }
388
+
389
+ ProcessGroupXCCL::~ProcessGroupXCCL () {
390
+ heartbeatMonitor_->stop ();
391
+ // Wait for all threads to finish before returning
392
+ heartbeatMonitor_->join ();
356
393
}
357
394
358
- ProcessGroupXCCL::~ProcessGroupXCCL () = default ;
395
+ bool ProcessGroupXCCL::dumpDebuggingInfo (bool includeStackTrace /* =true*/ ) {
396
+ STATIC_SCOPED_WAIT_COUNTER (pytorch.ProcessGroupXCCL__dumpDebuggingInfo );
397
+ LOG (ERROR)
398
+ << logPrefix ()
399
+ << " ProcessGroupXCCL preparing to dump debug info. Include stack trace: "
400
+ << includeStackTrace;
401
+ if (traceBufferSize_ > 0 ) {
402
+ // TODO: dump_xccl_trace
403
+ auto xcclTrace = dump_xccl_trace (true , includeStackTrace, false );
404
+ DebugInfoWriter& writer = DebugInfoWriter::getWriter (rank_);
405
+ LOG (INFO) << logPrefix () << " ProcessGroupXCCL dumping xccl trace to "
406
+ << writer.getWriterTarget ();
407
+ writer.write (xcclTrace);
408
+ LOG (INFO) << logPrefix () << " Flight Recorder trace successfully dumped." ;
409
+ return true ;
410
+ }
411
+ return false ;
412
+ }
413
+
414
+ const std::vector<uint64_t >& ProcessGroupXCCL::groupRanks () const {
415
+ if (options_->global_ranks_in_group .empty () && local_id_ == 0 ) {
416
+ static std::vector<uint64_t > globalRanks (size_);
417
+ std::iota (globalRanks.begin (), globalRanks.end (), 0 );
418
+ return globalRanks;
419
+ }
420
+ return options_->global_ranks_in_group ;
421
+ }
422
+
423
+ void ProcessGroupXCCL::setEnqueuedPgStatus (
424
+ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
425
+ pgStatus_->lastEnqueuedSeq = static_cast <int64_t >(work->getSequencenumber ());
426
+ pgStatus_->lastEnqueuedWorkName = opTypeToString (work->opType_ );
427
+ pgStatus_->lastEnqueuedNumelIn = work->numelIn_ ;
428
+ pgStatus_->lastEnqueuedNumelOut = work->numelOut_ ;
429
+ }
430
+
431
+ void ProcessGroupXCCL::setCompletedPgStatus (
432
+ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
433
+ pgStatus_->lastCompletedSeq = static_cast <int64_t >(work->getSequencenumber ());
434
+ pgStatus_->lastCompletedWorkName = opTypeToString (work->opType_ );
435
+ pgStatus_->lastCompletedNumelIn = work->numelIn_ ;
436
+ pgStatus_->lastCompletedNumelOut = work->numelOut_ ;
437
+ // To avoid complexity, we're not computing duration.
438
+ FlightRecorderXCCL::get ()->retire_id (work->trace_id_ , /* compute_duration*/ false );
439
+ }
359
440
360
441
void ProcessGroupXCCL::setSequenceNumberForGroup () {}
361
442
@@ -384,6 +465,21 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
384
465
profilingTitle,
385
466
profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
386
467
: std::nullopt);
468
+
469
+ r->trace_id_ = FlightRecorderXCCL::get ()->record (
470
+ local_id_,
471
+ std::make_tuple (pg_uid_, pg_desc_), // PG name tuple
472
+ seqCollective_,
473
+ seqP2P_,
474
+ op_id_,
475
+ profilingTitle ? profilingTitle : " " ,
476
+ inputs,
477
+ outputs,
478
+ nullptr ,
479
+ r->xcclEndEvent_ .get (),
480
+ options_->timeout ,
481
+ pgStatus_,
482
+ isP2P);
387
483
return r;
388
484
}
389
485
@@ -538,6 +634,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
538
634
groupEnd ();
539
635
540
636
work->xcclEndEvent_ ->record (stream);
637
+ setEnqueuedPgStatus (work);
541
638
542
639
coalescing_state_ = 0 ;
543
640
coalescedComm_ = nullptr ;
@@ -572,6 +669,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
572
669
if ((coalescing_state_ & CoalColl) == 0 ) {
573
670
seqCollective_++;
574
671
}
672
+ op_id_++;
575
673
coalescing_state_ |= CoalColl;
576
674
if (coalescedDevice_.index () < 0 ) {
577
675
coalescedDevice_ = device;
@@ -614,6 +712,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
614
712
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
615
713
work =
616
714
initWork (device, rank_, opType, false , profilingTitle, inputs, outputs);
715
+ if (coalescing_state_) {
716
+ FlightRecorderXCCL::get ()->record (
717
+ local_id_,
718
+ std::make_tuple (pg_uid_, pg_desc_), // PG name tuple
719
+ seqCollective_,
720
+ seqP2P_,
721
+ op_id_,
722
+ profilingTitle ? profilingTitle : " " ,
723
+ inputs,
724
+ outputs,
725
+ nullptr ,
726
+ nullptr ,
727
+ options_->timeout ,
728
+ pgStatus_,
729
+ false );
730
+ }
617
731
618
732
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
619
733
@@ -653,8 +767,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
653
767
work->future_ = c10::make_intrusive<at::ivalue::Future>(
654
768
c10::ListType::create (c10::TensorType::get ()), devices);
655
769
work->future_ ->markCompleted (at::IValue (*work->outputs_ ));
770
+ work->future_ ->addCallback (
771
+ [this , work](at::ivalue::Future&) {
772
+ this ->setCompletedPgStatus (work);
773
+ });
656
774
work->blockingWait_ = blockingWait_;
657
775
776
+ work->numelIn_ = 0 ;
777
+ work->numelOut_ = 0 ;
778
+ for (const auto & input : inputs) {
779
+ work->numelIn_ += input.numel ();
780
+ }
781
+ for (const auto & output : outputs) {
782
+ work->numelOut_ += output.numel ();
783
+ }
784
+ setEnqueuedPgStatus (work);
785
+
658
786
return asyncOp ? work : nullptr ;
659
787
}
660
788
@@ -687,6 +815,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
687
815
}
688
816
}
689
817
818
+ op_id_++;
690
819
auto comm = getXCCLComm (key, device, opType, p2pRank, isSendRecvSelf);
691
820
692
821
if (coalescing_state_ & CoalActive) {
@@ -722,6 +851,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
722
851
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
723
852
work->outputs_ ->push_back (tensor);
724
853
854
+ work->trace_id_ = FlightRecorderXCCL::get ()->record (
855
+ local_id_,
856
+ std::make_tuple (pg_uid_, pg_desc_), // PG name tuple
857
+ seqCollective_,
858
+ seqP2P_,
859
+ op_id_,
860
+ profilingTitle,
861
+ {tensor},
862
+ {tensor},
863
+ nullptr ,
864
+ work->xcclEndEvent_ .get (),
865
+ options_->timeout ,
866
+ pgStatus_,
867
+ true );
868
+
725
869
c10::OptionalDeviceGuard gpuGuard (device);
726
870
727
871
c10::xpu::XPUCachingAllocator::recordStream (
@@ -737,8 +881,29 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
737
881
work->future_ = c10::make_intrusive<at::ivalue::Future>(
738
882
c10::ListType::create (c10::TensorType::get ()), devices);
739
883
work->future_ ->markCompleted (at::IValue (*work->outputs_ ));
884
+ work->future_ ->addCallback (
885
+ [this , work](at::ivalue::Future&) {
886
+ this ->setCompletedPgStatus (work);
887
+ });
888
+
889
+ work->numelIn_ = work->numelOut_ = tensor.numel ();
890
+ setEnqueuedPgStatus (work);
740
891
return work;
741
892
} else {
893
+ FlightRecorderXCCL::get ()->record (
894
+ local_id_,
895
+ std::make_tuple (pg_uid_, pg_desc_), // PG name tuple
896
+ seqCollective_,
897
+ seqP2P_,
898
+ op_id_,
899
+ profilingTitle,
900
+ {tensor},
901
+ {tensor},
902
+ nullptr ,
903
+ nullptr ,
904
+ options_->timeout ,
905
+ pgStatus_,
906
+ true );
742
907
c10::OptionalDeviceGuard gpuGuard (device);
743
908
744
909
c10::xpu::XPUCachingAllocator::recordStream (
@@ -2135,6 +2300,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
2135
2300
" xccl:all_to_all" );
2136
2301
}
2137
2302
2303
+ std::string getXcclVersion () {
2304
+ auto xccl_version = ccl::get_library_version ();
2305
+ std::string versionString = std::to_string (xccl_version.major ) + " ." +
2306
+ std::to_string (xccl_version.minor ) + " ." +
2307
+ std::to_string (xccl_version.update );
2308
+ return versionString;
2309
+ }
2310
+
2138
2311
} // namespace c10d
2139
2312
2140
2313
#endif // USE_C10D_XCCL
0 commit comments