diff --git a/build/cmake_deps.toml b/build/cmake_deps.toml index 1430ea3a9ef..a7ee84a0677 100644 --- a/build/cmake_deps.toml +++ b/build/cmake_deps.toml @@ -73,6 +73,7 @@ excludes = [ deps = [ "executorch", "executorch_no_prim_ops", + "extension_threadpool", "portable_kernels", ] @@ -197,6 +198,18 @@ deps = [ "executorch", "executorch_no_prim_ops", ] + +[targets.extension_threadpool] +buck_targets = [ + "//extension/threadpool:threadpool", +] +filters = [ + ".cpp$", +] +deps = [ + "executorch", + "executorch_no_prim_ops", +] # ---------------------------------- extension end ---------------------------------- # ---------------------------------- binary start ---------------------------------- @@ -333,6 +346,7 @@ deps = [ "executorch", "executorch_no_prim_ops", "optimized_kernels", + "extension_threadpool", "xnnpack_backend", ] diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt index 858e51160e5..70d19343e4a 100644 --- a/kernels/optimized/CMakeLists.txt +++ b/kernels/optimized/CMakeLists.txt @@ -42,7 +42,9 @@ endif() # Build cpublas. list(TRANSFORM _optimized_cpublas__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(cpublas STATIC ${_optimized_cpublas__srcs}) -target_link_libraries(cpublas PRIVATE executorch_no_prim_ops eigen_blas) +target_link_libraries( + cpublas PRIVATE executorch_no_prim_ops eigen_blas extension_threadpool +) target_compile_options(cpublas PUBLIC ${_common_compile_options}) # Generate C++ bindings to register kernels into both PyTorch (for AOT) and @@ -58,7 +60,9 @@ message("Generated files ${gen_command_sources}") list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(optimized_kernels ${_optimized_kernels__srcs}) -target_link_libraries(optimized_kernels PRIVATE executorch_no_prim_ops cpublas) +target_link_libraries( + optimized_kernels PRIVATE executorch_no_prim_ops cpublas extension_threadpool +) target_compile_options(optimized_kernels PUBLIC ${_common_compile_options}) # Build a library for _optimized_kernels_srcs # diff --git a/kernels/optimized/blas/BlasKernel.h b/kernels/optimized/blas/BlasKernel.h index f594d1748e7..2c03ed0b638 100644 --- a/kernels/optimized/blas/BlasKernel.h +++ b/kernels/optimized/blas/BlasKernel.h @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -177,34 +178,37 @@ inline void gemm_transa_( torch::executor::BFloat16 beta, torch::executor::BFloat16 *c, int64_t ldc) { // c = alpha * (a.T @ b) + beta * c -// parallel_for(0, m, 1, [&](int64_t begin, int64_t end) { if (alpha == 1 && beta == 0) { - const auto *a_ = a; - for (int i = 0; i < m; ++i) { + executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) { + const auto *a_ = a + begin * lda; + for (int i = begin; i < end; ++i) { + const auto *b_ = b; + for (int j = 0; j < n; ++j) { + const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k); + b_ += ldb; + c[j*ldc+i] = dot; + } + a_ += lda; + } + }); + return; + } + executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) { + const auto *a_ = a + begin * lda; + for (int i = begin; i < end; ++i) { const auto *b_ = b; for (int j = 0; j < n; ++j) { const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k); b_ += ldb; - c[j*ldc+i] = dot; + if (beta == 0) { + c[j*ldc+i] = alpha*dot; + } else { + c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot; + } } a_ += lda; } - return; - } - const auto *a_ = a; - for (int i = 0; i < m; ++i) { - const auto *b_ = b; - for (int j = 0; j < n; ++j) { - const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k); - b_ += ldb; - if (beta == 0) { - c[j*ldc+i] = alpha*dot; - } else { - c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot; - } - } - a_ += lda; - } + }); } #endif diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index 16ce446df40..04ee0cfde42 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -157,6 +157,7 @@ def define_libs(): "DEFAULT": [], }), exported_deps = [ + "//executorch/extension/parallel:thread_parallel", "//executorch/kernels/optimized:libutils", "//executorch/runtime/core/exec_aten:lib", ], diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index 791c2184e9f..da40c91dab0 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -253,9 +253,12 @@ et_cxx_test( SOURCES ${_optimized_kernels_test_sources} EXTRA_LIBS + cpuinfo + extension_threadpool optimized_kernels optimized_ops_lib portable_kernels + pthreadpool eigen_blas ) add_dependencies(optimized_kernels_test generate_wrapper)