Skip to content

Commit 227c441

Browse files
add bf16 support for few ops (#20385)
### Description: Add bf16 support for the following ops: ConstantOfShape, Exp, Erf, convolution, PythonOp. ### Motivation and Context: The phimm model runs in bf16, so ORT needs bf16 support for the ops listed above in order to run phimm in bf16.
1 parent 464f199 commit 227c441

File tree

15 files changed

+174
-10
lines changed

15 files changed

+174
-10
lines changed

docs/OperatorKernels.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ Do not modify directly.*
7171
|ConcatFromSequence|*in* input_sequence:**S**<br> *out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
7272
|ConstantOfShape|*in* input:**T1**<br> *out* output:**T2**|21+|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
7373
|||20|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
74-
|||[9, 19]|**T1** = tensor(int64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
74+
|||[9, 19]|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
7575
|Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|11+|**T** = tensor(float)|
7676
|||[1, 10]|**T** = tensor(float)|
7777
|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int32)|
@@ -601,9 +601,9 @@ Do not modify directly.*
601601
|Equal|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
602602
|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
603603
|||[7, 10]|**T** = tensor(bool), tensor(int32), tensor(int64)|
604-
|Erf|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
604+
|Erf|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
605605
|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
606-
|Exp|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
606+
|Exp|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
607607
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
608608
|Expand|*in* input:**T**<br> *in* shape:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
609609
|||[8, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515

1616
namespace onnxruntime {
1717

18+
// Add bf16 support for ConstantOfShape operator for phimm model.
19+
// Although the ONNX spec does not list bf16 for ConstantOfShape in opset 9, we add support for it here. Spec:
20+
// https://github.com/onnx/onnx/blob/main/docs/Changelog.md#constantofshape-9
1821
using ConstantOfShapeDefaultOutputTypes =
1922
TypeList<
2023
MLFloat16,
2124
float, double,
2225
int8_t, int16_t, int32_t, int64_t,
2326
uint8_t, uint16_t, uint32_t, uint64_t,
24-
bool>;
27+
bool, BFloat16>;
2528

2629
using ConstantOfShapeDefaultOutputTypesOpset20 =
2730
TypeList<
@@ -158,6 +161,7 @@ void ConstantOfShapeBase<EnabledOutputTypeList>::SetValueFromTensorProto(const O
158161
CASE_FETCH_VALUE_DATA(uint16_t)
159162
CASE_FETCH_VALUE_DATA(uint32_t)
160163
CASE_FETCH_VALUE_DATA(uint64_t)
164+
CASE_FETCH_VALUE_DATA(BFloat16)
161165
default:
162166
ORT_THROW("Unsupported value attribute datatype: ", tensor_type);
163167
}

onnxruntime/core/providers/cuda/cu_inc/common.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ __device__ __inline__ double _Erf(double a) { return erf(a); }
231231
template <>
232232
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }
233233

234+
template <>
235+
__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); }
236+
234237
template <typename T>
235238
__device__ __host__ __inline__ T _Round(T a);
236239

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,9 +1031,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
10311031
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Exp);
10321032
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Exp);
10331033
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp);
1034+
// Add bf16 support for Exp in opset 13+ for phimm model
1035+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Exp);
10341036
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Erf);
10351037
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Erf);
10361038
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf);
1039+
// Add bf16 support for Erf in opset 13+ for phimm model
1040+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Erf);
10371041
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Expand);
10381042
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum);
10391043
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max);
@@ -1947,9 +1951,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
19471951
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Exp)>,
19481952
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Exp)>,
19491953
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp)>,
1954+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Exp)>,
19501955
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Erf)>,
19511956
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Erf)>,
19521957
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf)>,
1958+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Erf)>,
19531959
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Expand)>,
19541960
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum)>,
19551961
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max)>,

onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,8 @@ UNARY_OP_HFD(Ceil, 13)
244244
UNARY_OP_HFD(Reciprocal, 13)
245245
UNARY_OP_HFDX(Sqrt, 13)
246246
UNARY_OP_HFD(Log, 13)
247-
UNARY_OP_HFD(Exp, 13)
248-
UNARY_OP_HFD(Erf, 13)
247+
UNARY_OP_HFDX(Exp, 13)
248+
UNARY_OP_HFDX(Erf, 13)
249249
UNARY_OP_BWUZCSILHFD(Sign, 13)
250250

251251
UNARY_LOGICALOP_NOT_TYPED(1, bool)

onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal)
8787
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Sqrt)
8888
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log)
8989
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp)
90-
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
90+
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Erf)
9191
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
9292
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
9393
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)

onnxruntime/core/providers/rocm/cu_inc/common.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ __device__ __inline__ double _Erf(double a) { return erf(a); }
138138
template <>
139139
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }
140140

141+
template <>
142+
__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); }
143+
141144
template <typename T>
142145
__device__ __inline__ T _Round(T a);
143146

onnxruntime/core/providers/rocm/rocm_execution_provider.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,9 +1021,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
10211021
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Exp);
10221022
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Exp);
10231023
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp);
1024+
// Add bf16 support for Exp in opset 13+ for phimm model
1025+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Exp);
10241026
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Erf);
10251027
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Erf);
10261028
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf);
1029+
// Add bf16 support for Erf in opset 13+ for phimm model
1030+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Erf);
10271031
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Expand);
10281032
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Sum);
10291033
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Max);
@@ -1973,9 +1977,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
19731977
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Exp)>,
19741978
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Exp)>,
19751979
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp)>,
1980+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Exp)>,
19761981
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Erf)>,
19771982
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Erf)>,
19781983
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf)>,
1984+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Erf)>,
19791985
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Expand)>,
19801986
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Sum)>,
19811987
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Max)>,

onnxruntime/core/providers/shared_library/provider_bridge_provider.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,8 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d
478478
template <>
479479
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
480480
template <>
481+
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
482+
template <>
481483
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
482484
template <>
483485
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }

onnxruntime/core/providers/shared_library/provider_interfaces.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ struct ProviderHost {
198198
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0;
199199
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ double* p_data, size_t expected_size) = 0;
200200
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) = 0;
201+
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) = 0;
201202
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) = 0;
202203
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) = 0;
203204
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int16_t* p_data, size_t expected_size) = 0;

0 commit comments

Comments
 (0)