From f6ecefa29be1b5d169a43b26f1b722a628f7615d Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 10 Sep 2024 16:16:12 -0700 Subject: [PATCH] [ExecuTorch] Use parallel_for in bfloat16 gemm_transa_ kernel The upstream kernel uses this, I just didn't port it at first. Differential Revision: [D62154262](https://our.internmc.facebook.com/intern/diff/D62154262/) [ghstack-poisoned] --- kernels/optimized/blas/BlasKernel.h | 44 ++++++++++++++++------------- kernels/optimized/lib_defs.bzl | 4 ++- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/kernels/optimized/blas/BlasKernel.h b/kernels/optimized/blas/BlasKernel.h index f594d1748e7..2c03ed0b638 100644 --- a/kernels/optimized/blas/BlasKernel.h +++ b/kernels/optimized/blas/BlasKernel.h @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -177,34 +178,37 @@ inline void gemm_transa_( torch::executor::BFloat16 beta, torch::executor::BFloat16 *c, int64_t ldc) { // c = alpha * (a.T @ b) + beta * c -// parallel_for(0, m, 1, [&](int64_t begin, int64_t end) { if (alpha == 1 && beta == 0) { - const auto *a_ = a; - for (int i = 0; i < m; ++i) { + executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) { + const auto *a_ = a + begin * lda; + for (int i = begin; i < end; ++i) { + const auto *b_ = b; + for (int j = 0; j < n; ++j) { + const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k); + b_ += ldb; + c[j*ldc+i] = dot; + } + a_ += lda; + } + }); + return; + } + executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) { + const auto *a_ = a + begin * lda; + for (int i = begin; i < end; ++i) { const auto *b_ = b; for (int j = 0; j < n; ++j) { const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k); b_ += ldb; - c[j*ldc+i] = dot; + if (beta == 0) { + c[j*ldc+i] = alpha*dot; + } else { + c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot; + } } a_ += lda; } - return; - } - const auto *a_ = a; - for (int i = 0; i < m; ++i) { - const auto *b_ = b; - for (int j = 0; j < n; ++j) { - const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k); - b_ += ldb; - if (beta == 0) { - c[j*ldc+i] = alpha*dot; - } else { - c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot; - } - } - a_ += lda; - } + }); } #endif diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index 16ce446df40..23bfda9d5a6 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -155,7 +155,9 @@ def define_libs(): deps = select({ ":linux-x86_64": [mkl_dep] if not runtime.is_oss else [], "DEFAULT": [], - }), + }) + [ + "//executorch/extension/parallel:thread_parallel", + ], exported_deps = [ "//executorch/kernels/optimized:libutils", "//executorch/runtime/core/exec_aten:lib",