@@ -24,16 +24,16 @@ void ComputeJob(
     const T* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
-    IAllocatorUniquePtr<float>& scale_float_uptr,
-    IAllocatorUniquePtr<float>& bias_float_uptr,
+    const float* scale_float_ptr,
+    const float* bias_float_ptr,
     float epsilon,
     bool simplified,
     T* Y_data,
     U* mean_data,
     U* inv_std_dev_data,
     AllocatorPtr alloc) {
-  ORT_UNUSED_PARAMETER(scale_float_uptr);  // only used in MLFloat16 overload
-  ORT_UNUSED_PARAMETER(bias_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(scale_float_ptr);  // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(bias_float_ptr);   // only used in MLFloat16 overload
   ORT_UNUSED_PARAMETER(alloc);
 
   const T* p_input = X_data + task_idx * norm_size;
@@ -82,14 +82,17 @@ void ComputeJob(
     const MLFloat16* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
-    IAllocatorUniquePtr<float>& scale_float_uptr,
-    IAllocatorUniquePtr<float>& bias_float_uptr,
+    const float* scale_float_ptr,
+    const float* bias_float_ptr,
     float epsilon,
     bool simplified,
     MLFloat16* Y_data,
     U* mean_data,
     U* inv_std_dev_data,
     AllocatorPtr alloc) {
+  ORT_UNUSED_PARAMETER(scale_data);  // only used in float/double overload
+  ORT_UNUSED_PARAMETER(bias_data);   // only used in float/double overload
+
   const MLFloat16* p_input = X_data + task_idx * norm_size;
   MLFloat16* p_output = Y_data + task_idx * norm_size;
 
@@ -117,22 +120,10 @@ void ComputeJob(
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }
 
-  if (!scale_float_uptr) {
-    scale_float_uptr = std::move(input_float_uptr);  // overwrite input with scale values, since they have the same size
-    MlasConvertHalfToFloatBuffer(scale_data, scale_float_uptr.get(), num_elems);
-  }
-
-  if (bias_data && !bias_float_uptr) {
-    bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
-    MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
-  }
-
-  const float* scale_float_ptr = scale_float_uptr.get();
-  const float* bias_float_ptr = bias_float_uptr.get();
   for (size_t h = 0; h < num_elems; h++) {
     if (simplified) {
       output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
-    } else if (nullptr == bias_data) {
+    } else if (nullptr == bias_float_ptr) {
       output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
     } else {
       output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
@@ -166,7 +157,13 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I
 }  // namespace
 
 LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op)
-    : OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op}, scale_fp32_(nullptr), bias_fp32_(nullptr) {
+    : OpKernel(op_kernel_info),
+      simplified_{simplified},
+      contrib_op_{contrib_op},
+      prepacked_scale_fp32_data_(nullptr),
+      prepacked_scale_fp32_size_(0),
+      prepacked_bias_fp32_data_(nullptr),
+      prepacked_bias_fp32_size_(0) {
   ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
 }
@@ -175,15 +172,15 @@ template <typename T, typename U>
 Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const {
   // Inputs
   const Tensor* X = p_ctx->Input<Tensor>(0);
-  const Tensor* scale = p_ctx->Input<Tensor>(1);
-  const Tensor* bias = p_ctx->Input<Tensor>(2);
+  const Tensor* scale = prepacked_scale_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
+  const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
 
   const T* X_data = X->Data<T>();
-  const T* scale_data = scale->Data<T>();
+  const T* scale_data = scale ? scale->Data<T>() : nullptr;
   const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();
 
   const TensorShape& x_shape = X->Shape();
-  const TensorShape& scale_shape = scale->Shape();
-  const TensorShape& bias_shape = bias->Shape();
+  size_t scale_size = scale ? static_cast<size_t>(scale->Shape().Size()) : prepacked_scale_fp32_size_;
+  size_t bias_size = bias ? static_cast<size_t>(bias->Shape().Size()) : prepacked_bias_fp32_size_;
   Tensor* Y = p_ctx->Output(0, x_shape);
   T* Y_data = Y->MutableData<T>();
@@ -218,7 +215,7 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo
 
   AllocatorPtr alloc;
   ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
-  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
+  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data,
                                      inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
 }
 
@@ -237,9 +234,11 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
 
   is_packed = false;
   if (input_idx == 1) {  // scale
-    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, scale_fp32_, is_packed);
+    prepacked_scale_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed);
   } else if (input_idx == 2) {  // bias
-    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
+    prepacked_bias_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
   }
 
   return Status::OK();
@@ -250,9 +249,9 @@ Status LayerNormImpl::ComputeWithoutContext(
     const T* X_data,
     const TensorShape& x_shape,
     const T* scale_data,
-    const TensorShape& scale_shape,
+    size_t scale_size,
     const T* bias_data,
-    const TensorShape& bias_shape,
+    size_t bias_size,
     T* Y_data,
     U* mean_data,
     U* inv_std_dev_data,
@@ -264,19 +263,34 @@ Status LayerNormImpl::ComputeWithoutContext(
   int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
   int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
 
-  const auto scale_size = scale_shape.Size();
-  const auto bias_size = (bias_data) ? bias_shape.Size() : 0;
-  if (scale_size != norm_size || (bias_data && bias_size != norm_size)) {
+  if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "Size of X.shape()[axis:] == ", norm_size,
                            ". Size of scale and bias (if provided) must match this. Got scale size of ",
                            scale_size, " and bias size of ", bias_size);
   }
 
+  IAllocatorUniquePtr<float> scale_fp32;
+  IAllocatorUniquePtr<float> bias_fp32;
+  if constexpr (std::is_same_v<T, MLFloat16>) {
+    if (prepacked_scale_fp32_data_ == nullptr) {
+      const size_t num_elems = static_cast<size_t>(norm_size);
+      scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+      MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
+    }
+    if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
+      const size_t num_elems = static_cast<size_t>(norm_size);
+      bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+      MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
+    }
+  }
+
   concurrency::ThreadPool::TryBatchParallelFor(
       thread_pool, static_cast<int32_t>(norm_count),
       [&](ptrdiff_t task_idx) {
-        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, scale_fp32_, bias_fp32_,
+        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size,
+                   prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
+                   prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
                    epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
       },
       0);
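
For context, a minimal standalone sketch of the pattern this diff adopts for the MLFloat16 path: convert the fp16 scale/bias to fp32 once, up front, then hand every parallel job the same read-only const float*, rather than letting the first job lazily populate shared IAllocatorUniquePtr members. Everything below is illustrative, not ORT code: ConvertHalfToFloat and the std::vector buffers are hypothetical stand-ins for MlasConvertHalfToFloatBuffer and ORT's allocator.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for MlasConvertHalfToFloatBuffer; a real fp16 decode
// would expand sign/exponent/mantissa instead of this trivial placeholder.
static void ConvertHalfToFloat(const uint16_t* src, float* dst, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    dst[i] = static_cast<float>(src[i]);
  }
}

// Mirrors the new ComputeJob contract: read-only fp32 pointers, bias optional.
static void ComputeJob(const float* scale_fp32, const float* bias_fp32,
                       float* out, size_t norm_size) {
  for (size_t h = 0; h < norm_size; ++h) {
    out[h] = out[h] * scale_fp32[h] + (bias_fp32 ? bias_fp32[h] : 0.0f);
  }
}

int main() {
  constexpr size_t norm_size = 4;
  const uint16_t scale_f16[norm_size] = {1, 2, 3, 4};  // pretend fp16 payload

  // Convert once before dispatch, as the patched ComputeWithoutContext does;
  // each job then reads this buffer through a const pointer, so nothing
  // writes to shared state inside the parallel loop.
  std::vector<float> scale_fp32(norm_size);
  ConvertHalfToFloat(scale_f16, scale_fp32.data(), norm_size);

  std::vector<float> out(norm_size, 1.0f);
  ComputeJob(scale_fp32.data(), /*bias_fp32=*/nullptr, out.data(), norm_size);

  for (float v : out) std::printf("%g ", v);  // prints: 1 2 3 4
  std::printf("\n");
  return 0;
}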