 // Licensed under the MIT License.

 #include "core/framework/tensor.h"
+#include "core/mlas/inc/mlas.h"
 #include "core/util/math_cpuonly.h"
 #include "core/providers/common.h"
 #include "core/platform/threadpool.h"
@@ -36,52 +37,188 @@ REGISTER_KERNEL_TYPED(float)
 REGISTER_KERNEL_TYPED(double)
 REGISTER_KERNEL_TYPED(MLFloat16)

-// Utility to convert from MLFloat16 to float only when the input type is MLFloat16.
-template <typename T, typename Ret>
-ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);
-
-template <>
-ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
-  return val.ToFloat();
-}
-
-template <>
-ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, double>(MLFloat16 val) {
-  return static_cast<double>(ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(val));
+namespace {
+
+template <typename T, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
+void ComputeJob(
+    const T* input_data,
+    const T* skip_data,
+    const T* gamma_data,
+    const T* beta_data,
+    const T* bias_data,
+    IAllocatorUniquePtr<float>& skip_float_uptr,
+    IAllocatorUniquePtr<float>& gamma_float_uptr,
+    IAllocatorUniquePtr<float>& beta_float_uptr,
+    IAllocatorUniquePtr<float>& bias_float_uptr,
+    ptrdiff_t task_idx,
+    int hidden_size,
+    int64_t skip_size,
+    float epsilon,
+    bool simplified,
+    T* output_data,
+    T* skip_input_bias_add_output_data,
+    AllocatorPtr alloc) {
+  ORT_UNUSED_PARAMETER(skip_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(gamma_float_uptr);  // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(beta_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(bias_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(alloc);
+
+  auto offset = task_idx * hidden_size;
+  const T* p_input = input_data + offset;
+  const T* p_skip = skip_data + (offset % skip_size);
+  T* p_output = output_data + offset;
+  T* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;
+
+  T mean(0.0f);
+  T mean_square(0.0f);
+
+  for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
+    T val = p_input[h] + p_skip[h];
+
+    if (nullptr != bias_data) {
+      val += bias_data[h];
+    }
+
+    if (nullptr != p_skip_input_bias_add_output) {
+      p_skip_input_bias_add_output[h] = val;
+    }
+
+    p_output[h] = val;
+    mean += val;
+    mean_square += val * val;
+  }
+
+  mean = mean / hidden_size;
+  if (simplified) {
+    mean_square = sqrt(mean_square / hidden_size + epsilon);
+  } else {
+    mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
+  }
+
+  for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
+    if (simplified) {
+      p_output[h] = p_output[h] / mean_square * gamma_data[h];
+    } else if (nullptr == beta_data) {
+      p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h];
+    } else {
+      p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
+    }
+  }
 }

-template <>
-ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
-  return val;
+void ComputeJob(
+    const MLFloat16* input_data,
+    const MLFloat16* skip_data,
+    const MLFloat16* gamma_data,
+    const MLFloat16* beta_data,
+    const MLFloat16* bias_data,
+    IAllocatorUniquePtr<float>& skip_float_uptr,
+    IAllocatorUniquePtr<float>& gamma_float_uptr,
+    IAllocatorUniquePtr<float>& beta_float_uptr,
+    IAllocatorUniquePtr<float>& bias_float_uptr,
+    ptrdiff_t task_idx,
+    int hidden_size,
+    int64_t skip_size,
+    float epsilon,
+    bool simplified,
+    MLFloat16* output_data,
+    MLFloat16* skip_input_bias_add_output_data,
+    AllocatorPtr alloc) {
+  auto offset = task_idx * hidden_size;
+  const MLFloat16* p_input = input_data + offset;
+  const MLFloat16* p_skip = skip_data + (offset % skip_size);
+  MLFloat16* p_output = output_data + offset;
+  MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;
+
+  float mean(0.0f);
+  float mean_square(0.0f);
+  const size_t num_elems = static_cast<size_t>(hidden_size);
+
+  IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+  MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);
+
+  if (!skip_float_uptr) {
+    skip_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+    MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems);
+  }
+
+  if (bias_data && !bias_float_uptr) {
+    bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+    MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
+  }
+
+  IAllocatorUniquePtr<float> output_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+  float* output_float_ptr = output_float_uptr.get();
+
+  const float* input_float_ptr = input_float_uptr.get();
+  const float* skip_float_ptr = skip_float_uptr.get();
+  const float* bias_float_ptr = bias_float_uptr.get();
+  for (size_t h = 0; h < num_elems; h++) {
+    float val = input_float_ptr[h] + skip_float_ptr[h];
+
+    if (bias_float_uptr) {
+      val += bias_float_ptr[h];
+    }
+
+    output_float_ptr[h] = val;
+    mean += val;
+    mean_square += val * val;
+  }
+
+  if (nullptr != p_skip_input_bias_add_output) {
+    MlasConvertFloatToHalfBuffer(output_float_ptr, p_skip_input_bias_add_output, num_elems);
+  }
+
+  mean = mean / hidden_size;
+  if (simplified) {
+    mean_square = sqrt(mean_square / hidden_size + epsilon);
+  } else {
+    mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
+  }
+
+  if (!gamma_float_uptr) {
+    gamma_float_uptr = std::move(input_float_uptr);  // overwrite input with gamma values, since they have the same size
+    MlasConvertHalfToFloatBuffer(gamma_data, gamma_float_uptr.get(), num_elems);
+  }
+
+  if (beta_data && !beta_float_uptr) {
+    beta_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+    MlasConvertHalfToFloatBuffer(beta_data, beta_float_uptr.get(), num_elems);
+  }
+
+  const float* gamma_float_ptr = gamma_float_uptr.get();
+  const float* beta_float_ptr = beta_float_uptr.get();
+  for (size_t h = 0; h < num_elems; h++) {
+    if (simplified) {
+      output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h];
+    } else if (nullptr == beta_float_uptr) {
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h];
+    } else {
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h];
+    }
+  }
+
+  MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems);
 }

-template <>
-ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded<double, double>(double val) {
-  return val;
-}
+void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr<float>& dest, bool& is_packed) {
+  if (tensor.GetElementType() == utils::ToTensorProtoElementType<MLFloat16>()) {
+    auto tensor_data_ptr = tensor.Data<MLFloat16>();
+    auto tensor_size = static_cast<size_t>(tensor.Shape().Size());
+    auto float_ptr = IAllocator::MakeUniquePtr<float>(alloc, tensor_size, true);

-// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(T val) {
-  return val;
+    MlasConvertHalfToFloatBuffer(tensor_data_ptr, float_ptr.get(), tensor_size);
+    dest = std::move(float_ptr);
+    is_packed = true;
+  }
 }

-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(float val) {
-  return MLFloat16(val);
-}
-
-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(double val) {
-  return MLFloat16(static_cast<float>(val));
-}
+}  // namespace

 template <typename T, bool simplified>
 SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
-    : OpKernel(op_kernel_info) {
+    : OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) {
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
   ORT_ENFORCE(epsilon_ >= 0);
 }
@@ -94,8 +231,7 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
   const Tensor* beta = p_ctx->Input<Tensor>(3);
   const Tensor* bias = p_ctx->Input<Tensor>(4);
   Tensor* output = p_ctx->Output(0, input->Shape());
-  // For inferencing, we support one more optional output which is the sum
-  // of the input and skip tensors
+  // For inferencing, we support one more optional output which is the sum of the input and skip tensors
   Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());

   const auto& input_dims = input->Shape().GetDims();
@@ -120,75 +256,44 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {

   T* output_data = output->MutableData<T>();

-  // For inferencing, we support one more optional output which is the sum
-  // of the input and skip tensors
-  T* skip_input_bias_add_output_data = skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData<T>() : nullptr;
+  // For inferencing, we support one more optional output which is the sum of the input and skip tensors
+  T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();

-  const auto& skip_size = skip->Shape().Size();
+  const int64_t& skip_size = skip->Shape().Size();
+
+  AllocatorPtr alloc;
+  ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

   concurrency::ThreadPool::TryBatchParallelFor(
       p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
       [&](ptrdiff_t task_idx) {
-        auto offset = task_idx * hidden_size;
-
-        const T* p_input = input_data + offset;
-        const T* p_skip = skip_data + (offset % skip_size);
-        T* p_output = output_data + offset;
-        T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr;
-
-        using DoubleOrFloat = typename std::conditional<
-            std::is_same<T, double>::value,  // If T is double
-            double,                          // Use double
-            float                            // Otherwise, use float (covers float and MLFloat16)
-            >::type;
-
-        DoubleOrFloat mean(0.0f);
-        DoubleOrFloat mean_square(0.0f);
-
-        std::unique_ptr<DoubleOrFloat[]> output_buffer = std::make_unique<DoubleOrFloat[]>(hidden_size);
-        for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
-          DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
-          DoubleOrFloat skip_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_skip[h]);
-
-          DoubleOrFloat value = input_value + skip_value;
-
-          if (nullptr != bias_data) {
-            value += ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
-          }
-
-          output_buffer[h] = value;
-          T converted_value = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(value);
-          if (nullptr != p_skip_input_bias_add_output_data) {
-            p_skip_input_bias_add_output_data[h] = converted_value;
-          }
-
-          mean += value;
-          mean_square += value * value;
-        }
-
-        mean = mean / hidden_size;
-        if (simplified) {
-          mean_square = sqrt(mean_square / hidden_size + epsilon_);
-        } else {
-          mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
-        }
-
-        for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
-          DoubleOrFloat gamma_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(gamma_data[h]);
-          if (simplified) {
-            p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(output_buffer[h] / mean_square * gamma_value);
-          } else if (nullptr == beta_data) {
-            p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value);
-          } else {
-            DoubleOrFloat beta_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(beta_data[h]);
-            p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value + beta_value);
-          }
-        }
+        ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_, beta_fp32_,
+                   bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
+                   skip_input_bias_add_output_data, alloc);
       },
       0);

   return Status::OK();
 }

+template <typename T, bool simplified>
+Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+                                             bool& is_packed, PrePackedWeights* prepacked_weights) {
+  ORT_UNUSED_PARAMETER(prepacked_weights);
+
+  is_packed = false;
+  if (input_idx == 1) {  // skip
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed);
+  } else if (input_idx == 2) {  // gamma
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed);
+  } else if (input_idx == 3) {  // beta
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed);
+  } else if (input_idx == 4) {  // bias
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
+  }
+
+  return Status::OK();
+}
+
 }  // namespace contrib
 }  // namespace onnxruntime