Skip to content

Commit c6a2675

Browse files
authored
Fix CUDA EP Abs and Sign bfloat16 support (microsoft#23914)
### Description <!-- Describe your changes. --> Abs and Sign had bfloat16 kernels created but not registered with the CUDA EP. Additionally Sign bfloat16 didn't work. * register bfloat16 kernels with CUDA EP * fix incorrectly named macro by adding 'X' as they add bfloat16 registration * add specialization for bfloat16 to _Sign * copied existing pattern. not sure if there's a better way * update tests ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> microsoft#23875
1 parent e495750 commit c6a2675

File tree

6 files changed

+47
-17
lines changed

6 files changed

+47
-17
lines changed

docs/OperatorKernels.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ Do not modify directly.*
582582
| Op Name | Parameters | OpSet Version | Types Supported |
583583
|---------|------------|---------------|-----------------|
584584
|**Operator Domain:** *ai.onnx*||||
585-
|Abs|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
585+
|Abs|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
586586
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
587587
|Add|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
588588
|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
@@ -839,7 +839,7 @@ Do not modify directly.*
839839
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
840840
|Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
841841
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
842-
|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
842+
|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
843843
|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
844844
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)|
845845
|Size|*in* data:**T**<br> *out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|

onnxruntime/core/providers/cuda/cu_inc/common.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,9 @@ __device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed<T>()); }
441441
template <>
442442
__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }
443443

444+
template <>
445+
__device__ __inline__ BFloat16 _Sign(BFloat16 a) { return _Signum(static_cast<float>(a), std::true_type()); }
446+
444447
template <typename T>
445448
__device__ __inline__ T _Normcdf(T a);
446449

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
10131013
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Abs);
10141014
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Abs);
10151015
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Abs);
1016+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Abs);
10161017
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Neg);
10171018
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Neg);
10181019
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Neg);
@@ -1188,6 +1189,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
11881189
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign);
11891190
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign);
11901191
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
1192+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Sign);
11911193

11921194
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
11931195
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
@@ -1996,6 +1998,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
19961998
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Abs)>,
19971999
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Abs)>,
19982000
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Abs)>,
2001+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Abs)>,
19992002
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Neg)>,
20002003
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Neg)>,
20012004
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Neg)>,
@@ -2169,6 +2172,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
21692172
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign)>,
21702173
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign)>,
21712174
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign)>,
2175+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Sign)>,
21722176

21732177
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add)>,
21742178
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub)>,

onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -213,19 +213,19 @@ Status IsNaN::ComputeInternal(OpKernelContext* context) const {
213213
UNARY_OP_TYPED(name, ver, float) \
214214
UNARY_OP_TYPED(name, ver, double)
215215

216-
#define UNARY_OP_CSILHFD(name, ver) \
216+
#define UNARY_OP_CSILHFDX(name, ver) \
217217
UNARY_OP_TYPED(name, ver, int8_t) \
218218
UNARY_OP_TYPED(name, ver, int16_t) \
219219
UNARY_OP_TYPED(name, ver, int32_t) \
220220
UNARY_OP_TYPED(name, ver, int64_t) \
221221
UNARY_OP_HFDX(name, ver)
222222

223-
#define UNARY_OP_BWUZCSILHFD(name, ver) \
224-
UNARY_OP_TYPED(name, ver, uint8_t) \
225-
UNARY_OP_TYPED(name, ver, uint16_t) \
226-
UNARY_OP_TYPED(name, ver, uint32_t) \
227-
UNARY_OP_TYPED(name, ver, uint64_t) \
228-
UNARY_OP_CSILHFD(name, ver)
223+
#define UNARY_OP_BWUZCSILHFDX(name, ver) \
224+
UNARY_OP_TYPED(name, ver, uint8_t) \
225+
UNARY_OP_TYPED(name, ver, uint16_t) \
226+
UNARY_OP_TYPED(name, ver, uint32_t) \
227+
UNARY_OP_TYPED(name, ver, uint64_t) \
228+
UNARY_OP_CSILHFDX(name, ver)
229229

230230
UNARY_OP_VERSIONED_BWUZCSILHFD(Abs, 6, 12)
231231
UNARY_OP_VERSIONED_CSILHFD(Neg, 6, 12)
@@ -237,16 +237,16 @@ UNARY_OP_VERSIONED_HFD(Log, 6, 12)
237237
UNARY_OP_VERSIONED_HFD(Exp, 6, 12)
238238
UNARY_OP_VERSIONED_HFD(Erf, 9, 12)
239239

240-
UNARY_OP_BWUZCSILHFD(Abs, 13)
241-
UNARY_OP_CSILHFD(Neg, 13)
240+
UNARY_OP_BWUZCSILHFDX(Abs, 13)
241+
UNARY_OP_CSILHFDX(Neg, 13)
242242
UNARY_OP_HFD(Floor, 13)
243243
UNARY_OP_HFD(Ceil, 13)
244244
UNARY_OP_HFD(Reciprocal, 13)
245245
UNARY_OP_HFDX(Sqrt, 13)
246246
UNARY_OP_HFD(Log, 13)
247247
UNARY_OP_HFDX(Exp, 13)
248248
UNARY_OP_HFDX(Erf, 13)
249-
UNARY_OP_BWUZCSILHFD(Sign, 13)
249+
UNARY_OP_BWUZCSILHFDX(Sign, 13)
250250

251251
UNARY_LOGICALOP_NOT_TYPED(1, bool)
252252
UNARY_OP_HFD(Round, 11)

onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -968,8 +968,15 @@ TEST(MathOpTest, Abs) {
968968
test.Run();
969969
}
970970

971-
#ifdef USE_DNNL
971+
#if defined(USE_CUDA) || defined(USE_DNNL)
972972
TEST(MathOpTest, Abs_bfloat16) {
973+
#ifdef USE_CUDA
974+
int min_cuda_architecture = 530;
975+
if (!HasCudaEnvironment(min_cuda_architecture)) {
976+
LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16";
977+
return;
978+
}
979+
#endif
973980
#ifdef USE_DNNL
974981
if (!DnnlHasBF16Support()) {
975982
LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16";
@@ -980,9 +987,19 @@ TEST(MathOpTest, Abs_bfloat16) {
980987
std::vector<int64_t> dims{2, 2};
981988
test_bf16.AddInput<BFloat16>("X", dims, MakeBFloat16({1.0f, -2.0f, -0.0f, -10.0f}));
982989
test_bf16.AddOutput<BFloat16>("Y", dims, MakeBFloat16({1.0f, 2.0f, 0.0f, 10.0f}));
983-
test_bf16.Run();
990+
991+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
992+
#if defined(USE_CUDA)
993+
execution_providers.push_back(DefaultCudaExecutionProvider());
994+
#endif
995+
996+
#if defined(USE_DNNL)
997+
execution_providers.push_back(DefaultDnnlExecutionProvider());
998+
#endif
999+
1000+
test_bf16.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
9841001
}
985-
#endif // USE_DNNL
1002+
#endif // USE_CUDA || USE_DNNL
9861003

9871004
TEST(MathOpTest, Abs_int8) {
9881005
OpTester test("Abs");

onnxruntime/test/providers/cpu/math/sign_test.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ TEST(MathOpTest, Sign_MLFloat16) {
207207
// test.Run(OpTester::ExpectResult::kExpectSuccess);
208208
//}
209209

210-
#if defined(USE_DNNL)
210+
#if defined(USE_CUDA) || defined(USE_DNNL)
211211
TEST(MathOpTest, Sign_bfloat16) {
212212
#ifdef USE_DNNL
213213
if (!DnnlHasBF16Support()) {
@@ -228,9 +228,15 @@ TEST(MathOpTest, Sign_bfloat16) {
228228
TestImpl<BFloat16>(input.cbegin(), input.cend(), std::back_inserter(output));
229229
test.AddOutput<BFloat16>("output", input_dims, output);
230230
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
231+
232+
#if defined(USE_CUDA)
233+
execution_providers.push_back(DefaultCudaExecutionProvider());
234+
#endif
235+
231236
#if defined(USE_DNNL)
232237
execution_providers.push_back(DefaultDnnlExecutionProvider());
233-
#endif // USE_DNNL
238+
#endif
239+
234240
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
235241
}
236242
#endif

0 commit comments

Comments
 (0)