Skip to content

Commit 7fc7d5e

Browse files
author
Erick Muñoz
authored
Enable multithreading on FP16 to FP32 cast operator (microsoft#23619)
### Description
Enables multithreading on the FP16 to FP32 cast operator.

### Motivation and Context
Improves CPU performance on FP16 models that require casting to FP32.
1 parent eceae8b commit 7fc7d5e

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

onnxruntime/core/providers/cpu/tensor/cast_op.cc

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,11 +254,32 @@ struct TensorCasterNoSat<std::string, DstType> {
254254
// tensor MLFloat16 -> float
255255
template <>
256256
struct TensorCaster<MLFloat16, float> {
257-
void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
257+
void Cast(const OpKernelContext& ctx, const TensorShape& shape, const Tensor& in, Tensor& out) const {
258258
auto out_data = out.MutableData<float>();
259259
auto in_data = in.Data<MLFloat16>();
260260
const size_t shape_size = narrow<size_t>(shape.Size());
261-
MlasConvertHalfToFloatBuffer(in_data, out_data, shape_size);
261+
262+
// Check if the tensor is long enough to use threads
263+
if (shape_size <= 128000) {
264+
MlasConvertHalfToFloatBuffer(in_data, out_data, shape_size);
265+
return;
266+
}
267+
// Calculate the number of compute cyles per implementation
268+
auto cpu_info = CPUIDInfo::GetCPUIDInfo();
269+
double num_compute_cycles;
270+
if (cpu_info.HasSSE3()) {
271+
num_compute_cycles = static_cast<double>(shape_size >> 1);
272+
} else if (cpu_info.HasAVX2()) {
273+
num_compute_cycles = static_cast<double>(shape_size >> 2);
274+
} else {
275+
num_compute_cycles = static_cast<double>(shape_size * 10);
276+
}
277+
278+
concurrency::ThreadPool::TryParallelFor(ctx.GetOperatorThreadPool(), shape_size,
279+
{shape_size * 2.f, shape_size * 4.f, num_compute_cycles},
280+
[in_data, out_data](std::ptrdiff_t first_span, std::ptrdiff_t last_span) {
281+
MlasConvertHalfToFloatBuffer(in_data + first_span, out_data + first_span, static_cast<size_t>(last_span - first_span));
282+
});
262283
}
263284
};
264285

0 commit comments

Comments
 (0)