Commit 65d4902

Chao1Han and mengfei25 authored
Add an option for NAN check for xccl (#1756)
Following pytorch/pytorch#125726 and pytorch/pytorch#135414, add a NaN check for XCCL. Why do we need to stop communication from spreading NaNs? "Technically, if we can be sure which rank (or even which host) detected the first NaN, then it's OK to let the NaN spread to some other hosts. But in practice I don't know if we have a good enough way to align our logs on different hosts, so if we let the NaN spread to a few other hosts we may lose track of which one was first."

---------
Co-authored-by: mengfei25 <[email protected]>
1 parent ed3442d commit 65d4902

File tree

6 files changed, +368 −13 lines changed

src/xccl/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -2,10 +2,13 @@
 
 file(GLOB xccl_h "*.hpp")
 file(GLOB xccl_cpp "*.cpp")
+list(REMOVE_ITEM xccl_cpp "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck_XPU.cpp")
 
 list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp})
+list(APPEND ATen_XPU_SYCL_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck_XPU.cpp")
 
 set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE)
+set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE)
 
 # Why copy the header file to the build directory?
 # We want register XCCL backend to PyTorch c10d in torch/csrc/distributed/c10d/init.cpp#L27-L29.

src/xccl/NanCheck_XPU.cpp

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/xpu/sycl/MemoryAccessUtils.h>
#include <ATen/xpu/XPUContext.h>
#include <comm/SYCLContext.h>
#include <stdint.h>
#include <torch/torch.h>
#include <xccl/NanCheck_XPU.hpp>
#include <algorithm>

namespace c10d {

using BytePack = at::native::memory::aligned_vector<uint64_t, 2>;

template <typename T, int EltPerPack>
struct CheckBytePack {
  static void check(BytePack* tmp) {
    T* data = (T*)tmp;
#pragma unroll 8
    for (int i = 0; i < EltPerPack; i++) {
      if (at::_isnan(data[i]))
        assert(0);
    }
  }
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 2> {
  static void check(BytePack* tmp) {
    T* data = (T*)tmp;
    if (at::_isnan(data[0]) || at::_isnan(data[1]))
      assert(0);
  }
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 4> {
  static void check(BytePack* tmp) {
    T* data = (T*)tmp;
    if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) ||
        at::_isnan(data[3]))
      assert(0);
  }
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 8> {
  static void check(BytePack* tmp) {
    T* data = (T*)tmp;
    if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) ||
        at::_isnan(data[3]) || at::_isnan(data[4]) || at::_isnan(data[5]) ||
        at::_isnan(data[6]) || at::_isnan(data[7])) {
      assert(0);
    }
  }
};

template <typename T>
struct HasNanFP8x8 {
  static bool check(uint64_t fp8x8) = delete;
  /*
  {
    // `static_assert` in template definition requires c++23 onwards.
    // But the error message still applies if you find yourself here.
    static_assert(
        false,
        "You should never call this template definition because it is empty. You "
        "can follow the example of Float8_e4m3fn below to implement the check for "
        "your new datatype.");
  }
  */
};

template <>
struct HasNanFP8x8<c10::Float8_e4m3fn> {
  static bool check(uint64_t fp8x8) {
    auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
    auto incremented = t + 0x0101010101010101ULL;
    auto overflow = incremented & 0x8080808080808080ULL;
    return overflow != 0;
  }
};

template <>
struct HasNanFP8x8<c10::Float8_e5m2> {
  static bool check(uint64_t fp8x8) {
    auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
    auto incremented = t + 0x0303030303030303ULL;
    auto overflow = incremented & 0x8080808080808080ULL;
    return overflow != 0;
  }
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 16> {
  static void check(BytePack* tmp) {
    if (HasNanFP8x8<T>::check(tmp->val[0]) ||
        HasNanFP8x8<T>::check(tmp->val[1]))
      assert(0);
  }
};

#define UNROLL 8

template <typename T>
void checkChunk(BytePack* ptr, int nWorkers) {
  BytePack tmp[UNROLL];

#pragma unroll 8
  for (int j = 0; j < UNROLL; j++) {
    tmp[j] = ptr[nWorkers * j];
  }
  // Then check each BytePack in the tmp buffer
#pragma unroll 8
  for (int j = 0; j < UNROLL; j++) {
    CheckBytePack<T, sizeof(BytePack) / sizeof(T)>::check(tmp + j);
  }
  // Note: we separate the check from the load for efficient loading
}

// Align address of `ptr` up, to the alignment of `T`
#define ALIGN_UP(ptr, T) \
  (((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T))

template <typename T>
struct checkForNaN {
  void operator()(sycl::nd_item<1> item) const {
    constexpr int EltPerPack = sizeof(BytePack) / sizeof(T);

    size_t offset = item.get_global_id(0);

    // Align input address up to BytePack in case it is not
    T* ptrAlign = (T*)ALIGN_UP(data, BytePack);
    size_t preProcElts =
        std::min<size_t>(static_cast<size_t>(ptrAlign - data), size);

    size_t size_left = size;

    if (offset < preProcElts) {
      if (at::_isnan(data[offset]))
        assert(0);
    }
    size_left -= preProcElts;

    BytePack* ptr = (BytePack*)ptrAlign;
    size_t sizeInBP = size_left * sizeof(T) / sizeof(BytePack);
    size_t loopSize = item.get_global_range(0) * UNROLL;

    for (; offset + loopSize <= sizeInBP; offset += loopSize) {
      checkChunk<T>(ptr + offset, item.get_global_range(0));
    }

    for (; offset < sizeInBP; offset += item.get_global_range(0)) {
      BytePack tmp = ptr[offset];
      CheckBytePack<T, EltPerPack>::check(&tmp);
    }

    if (item.get_local_id(0) < size_left % EltPerPack) {
      T* tailPtr = (T*)(ptr + sizeInBP);
      if (at::_isnan(tailPtr[item.get_local_id(0)]))
        assert(0);
    }
  }
  checkForNaN(T* data, size_t size) : data(data), size(size) {}

 private:
  T* data;
  size_t size;
};

template <typename T>
void checkfornan_impl_xpu(
    const at::Tensor& tensor,
    at::xpu::XPUStream& stream) {
  // skip check for non float types
  if (!torch::is_floating_point(tensor)) {
    return;
  }

  int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize<checkForNaN<T>>();

  const size_t numThreadsPerBlock =
      std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());

  if (!(numThreadsPerBlock > 0)) {
    return;
  }

  int64_t numBlocks =
      (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock;
  auto global_range{numBlocks * numThreadsPerBlock};
  auto local_range{numThreadsPerBlock};

  using Kernel = checkForNaN<T>;
  auto kfn = Kernel(tensor.data_ptr<T>(), tensor.numel());

  sycl_kernel_submit(global_range, local_range, stream.queue(), kfn);
}

// CHECK if a Tensor contains NAN in any of its element
void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream) {
  AT_DISPATCH_FLOATING_TYPES_AND4(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      at::ScalarType::Float8_e4m3fn,
      at::ScalarType::Float8_e5m2,
      tensor.scalar_type(),
      "checkForNaN_XPU",
      [&]() { checkfornan_impl_xpu<scalar_t>(tensor, stream); });
}

} // namespace c10d
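The HasNanFP8x8 specializations screen eight packed FP8 values with plain integer arithmetic: clear the sign bits, then add a per-byte constant chosen so that only NaN encodings carry into bit 7 of their byte (for Float8_e4m3fn the sole NaN payload after sign-masking is 0x7F, hence +0x01; for Float8_e5m2 the NaN payloads are 0x7D through 0x7F, hence +0x03). A minimal host-side sketch of the same trick on plain uint64_t values, an illustration rather than part of this commit:

```cpp
#include <cassert>
#include <cstdint>

// Same bit trick as HasNanFP8x8<c10::Float8_e5m2>, written against plain
// integers so it can run on the host: clear the sign bits, then add 0x03 to
// every byte so only the NaN encodings (0x7D, 0x7E, 0x7F) set bit 7.
static bool has_nan_e5m2x8(uint64_t fp8x8) {
  uint64_t t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;        // drop sign bits
  uint64_t incremented = t + 0x0303030303030303ULL;  // per-byte add; no carry crosses a byte
  return (incremented & 0x8080808080808080ULL) != 0; // did any byte overflow into bit 7?
}

int main() {
  assert(!has_nan_e5m2x8(0x7C3C000001020304ULL)); // 0x7C is +inf, not NaN
  assert(has_nan_e5m2x8(0x7D00000000000000ULL));  // 0x7D encodes a NaN
  assert(has_nan_e5m2x8(0x00000000000000FFULL));  // 0xFF is a negative NaN
  return 0;
}
```

The masked bytes never exceed 0x7F, so the addition cannot carry from one byte into the next, which is what keeps the eight lanes independent.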

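The launch configuration in checkfornan_impl_xpu caps the work-group size at the device limit and takes a ceiling division for the group count, so the global range always covers numel. A small worked example of that arithmetic with a hypothetical tensor size and device limit (illustration only, not from the commit):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t numel = 1000003;         // hypothetical tensor size
  const int64_t maxWorkGroupSize = 1024; // hypothetical device limit

  // Work-group size: device limit, capped by the element count.
  const int64_t local = std::min<int64_t>(maxWorkGroupSize, numel);
  // Number of groups: ceiling division, so groups * local >= numel.
  const int64_t groups = (numel + local - 1) / local;
  const int64_t global = groups * local;

  // Prints: groups=977 local=1024 global=1000448 (covers all 1000003 elements).
  std::printf("groups=%lld local=%lld global=%lld\n",
              static_cast<long long>(groups),
              static_cast<long long>(local),
              static_cast<long long>(global));
  return 0;
}
```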
src/xccl/NanCheck_XPU.hpp

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
#pragma once

#ifdef USE_C10D_XCCL

#include <ATen/ATen.h>
#include <c10/xpu/XPUStream.h>

namespace c10d {

void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream);

} // namespace c10d

#endif // USE_C10D_XCCL
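The header exposes a single entry point. A hedged sketch of what a call site could look like; the stream lookup via c10::xpu::getCurrentXPUStream() is an assumption made for illustration and is not taken from this commit:

```cpp
#include <ATen/ATen.h>
#include <c10/xpu/XPUStream.h>
#include <xccl/NanCheck_XPU.hpp>

void check_input_before_comm(const at::Tensor& t) {
  // Assumption: use the current XPU stream; in practice, pass whichever
  // stream the communication kernel will run on.
  auto stream = c10::xpu::getCurrentXPUStream();
  // Launches the SYCL kernel; it asserts inside the kernel if any element is NaN.
  c10d::checkForNan(t, stream);
}
```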

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 30 additions & 5 deletions
@@ -1,6 +1,7 @@
 #ifdef USE_C10D_XCCL
 
 #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
+#include <xccl/NanCheck_XPU.hpp>
 #include <xccl/ProcessGroupXCCL.hpp>
 
 namespace c10d {
@@ -338,6 +339,7 @@ ProcessGroupXCCL::ProcessGroupXCCL(
       local_id_(process_group_id++) {
   logPrefix_ = createLogPrefix();
   blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
+  enableNanCheck_ = getCvarBool(TORCH_XCCL_NAN_CHECK, false);
   init();
   const std::string OFF = "OFF";
   std::string torch_distributed_debug =
@@ -349,7 +351,8 @@ ProcessGroupXCCL::ProcessGroupXCCL(
   LOG(INFO) << logPrefix() << "ProcessGroupXCCL environments: "
             << "XCCL version: " << XcclVersion
             << ", TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
-            << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug;
+            << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
+            << ", TORCH_XCCL_NAN_CHECK: " << enableNanCheck_;
 }
 
 ProcessGroupXCCL::~ProcessGroupXCCL() = default;
@@ -360,6 +363,10 @@ uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() {
   return seqCollective_;
 }
 
+void ProcessGroupXCCL::setEnableNanCheck(bool enableNanCheck) {
+  enableNanCheck_ = enableNanCheck;
+}
+
 c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
     at::Device& device,
     int rank,
@@ -553,7 +560,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
     PostProcess post,
     OpType opType,
     bool asyncOp,
-    const char* profilingTitle) {
+    const char* profilingTitle,
+    bool nanCheck) {
+  nanCheck &= enableNanCheck_;
   seqCollective_++;
   auto device = inputs[0].device();
   const auto key = std::to_string(device.index());
@@ -620,6 +629,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
 
   c10::OptionalDeviceGuard gpuGuard(device);
 
+  if (nanCheck) {
+    for (const auto& input : inputs) {
+      checkForNan(input, stream);
+    }
+  }
+
   pre(stream, work);
 
   for (const auto i : c10::irange(inputs.size())) {
@@ -697,6 +712,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
   auto cclstream = xcclStreamsMap_.at(key).second;
   syncStream(device, xcclEventsMap_[key], stream);
 
+  if (enableNanCheck_ && opType == OpType::SEND) {
+    checkForNan(tensor, stream);
+  }
+
   if (!coalescing_state_) {
     auto work =
         initWork(device, rank_, opType, true, profilingTitle, {tensor}, {});
@@ -1014,6 +1033,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
       "N/A"); // reductionOp
 
   const auto root = opts.rootRank;
+  bool nanCheck = (rank_ == root);
 
   auto outputs = std::vector<at::Tensor>{outputTensor};
   return collective(
@@ -1067,7 +1087,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
       },
       OpType::SCATTER,
       opts.asyncOp,
-      "xccl:scatter");
+      "xccl:scatter",
+      nanCheck);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
@@ -1236,6 +1257,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
       "N/A"); // reductionOp
 
   const auto root = opts.rootRank + opts.rootTensor;
+  bool nanCheck = (root == rank_);
 
   return collective(
       tensor,
@@ -1257,7 +1279,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
       },
       OpType::BROADCAST,
       opts.asyncOp,
-      "xccl:broadcast");
+      "xccl:broadcast",
+      nanCheck);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
@@ -1270,6 +1293,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
         "Tensor input and output of _broadcast_oop must have the same number of elements ");
   }
   const auto root = opts.rootRank + opts.rootTensor;
+  bool nanCheck = (root == rank_);
   return collective(
       inputTensor,
       outputTensor,
@@ -1291,7 +1315,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
       },
       OpType::BROADCAST,
       opts.asyncOp,
-      "xccl:_broadcast_oop");
+      "xccl:_broadcast_oop",
+      nanCheck);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
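Per this diff, the check is off by default and gated twice: ProcessGroupXCCL reads TORCH_XCCL_NAN_CHECK once at construction, and each collective also passes a per-op nanCheck flag (root rank only for broadcast/scatter, send side only for point-to-point). A hedged sketch of the two ways to turn it on; the helper function and its call site are illustrative and not part of this commit:

```cpp
#include <cstdlib>
#include <xccl/ProcessGroupXCCL.hpp>

void enable_xccl_nan_check(c10d::ProcessGroupXCCL& pg) {
  // Option 1: set the environment variable before the process group is
  // constructed, e.g. `TORCH_XCCL_NAN_CHECK=1 torchrun ...`; the constructor
  // reads it via getCvarBool(TORCH_XCCL_NAN_CHECK, false).
  setenv("TORCH_XCCL_NAN_CHECK", "1", /*overwrite=*/1);

  // Option 2: flip it at runtime on an existing process group.
  pg.setEnableNanCheck(true);
}
```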

0 commit comments