Skip to content

Commit 39eb871

Browse files
committed
Add an interface to set the number of threads for math function, and set the default value to 1 for inference.
1 parent 48ac978 commit 39eb871

File tree

5 files changed

+27
-5
lines changed

5 files changed

+27
-5
lines changed

cmake/external/openblas.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ IF(NOT ${CBLAS_FOUND})
2929
"${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}"
3030
CACHE FILEPATH "openblas library." FORCE)
3131

32+
ADD_DEFINITIONS(-DPADDLE_USE_OPENBLAS)
33+
3234
SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -Wno-unused-but-set-variable -Wno-unused-variable")
3335
SET(OPENBLAS_COMMIT "v0.2.20")
3436

paddle/fluid/inference/io.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,20 @@ limitations under the License. */
2020
#include "paddle/fluid/framework/block_desc.h"
2121
#include "paddle/fluid/framework/feed_fetch_type.h"
2222
#include "paddle/fluid/framework/op_registry.h"
23+
#include "paddle/fluid/operators/math/blas.h"
2324
#include "paddle/fluid/pybind/pybind.h"
2425

2526
DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
2627
DEFINE_bool(init_p2p, false, "Whether to init p2p.");
28+
DEFINE_int32(math_num_threads, 1,
29+
"Number of threads used to run math functions.");
2730

2831
namespace paddle {
2932
namespace inference {
3033

3134
void Init(const std::vector<std::string> argv) {
3235
framework::InitGflags(argv);
36+
operators::math::SetNumThreads(FLAGS_math_num_threads);
3337
// init devices
3438
std::vector<int> devices;
3539
std::string token;

paddle/fluid/operators/math/blas.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@
2020
#ifdef PADDLE_WITH_MKLML
2121
#include <mkl_cblas.h>
2222
#include <mkl_lapacke.h>
23+
#include <mkl_service.h>
2324
#include <mkl_vml_functions.h>
2425
#endif
2526

2627
#ifdef PADDLE_USE_OPENBLAS
2728
#include <cblas.h>
29+
#ifdef LAPACK_FOUND
2830
#include <lapacke.h>
2931
#endif
32+
#endif
3033

3134
#ifndef LAPACK_FOUND
3235
extern "C" {
@@ -46,6 +49,18 @@ namespace paddle {
4649
namespace operators {
4750
namespace math {
4851

52+
static void SetNumThreads(int num_threads) {
53+
#ifdef PADDLE_USE_OPENBLAS
54+
int real_num_threads = num_threads > 1 ? num_threads : 1;
55+
openblas_set_num_threads(real_num_threads);
56+
#elif defined(PADDLE_WITH_MKLML)
57+
int real_num_threads = num_threads > 1 ? num_threads : 1;
58+
mkl_set_num_threads(real_num_threads);
59+
#else
60+
PADDLE_ENFORCE(false, "To be implemented.");
61+
#endif
62+
}
63+
4964
/**
5065
* Matrix Descriptor of a memory buffer.
5166
*

paddle/fluid/operators/math/math_function.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ limitations under the License. */
2121

2222
#ifdef PADDLE_USE_OPENBLAS
2323
#include <cblas.h>
24+
#ifdef LAPACK_FOUND
2425
#include <lapacke.h>
2526
#endif
27+
#endif
2628

2729
#ifndef LAPACK_FOUND
2830
extern "C" {

paddle/math/MathFunctions.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#ifndef MATHFUNCTIONS_H_
16-
#define MATHFUNCTIONS_H_
15+
#pragma once
1716

1817
#ifdef PADDLE_WITH_MKLML
1918
#include <mkl_cblas.h>
2019
#include <mkl_lapacke.h>
2120
#include <mkl_vml_functions.h>
2221
#endif
2322

24-
#if defined(PADDLE_USE_VECLIB)
23+
#ifdef PADDLE_USE_VECLIB
2524
extern "C" {
2625
#include <cblas.h>
2726
#include <clapack.h>
@@ -30,8 +29,10 @@ extern "C" {
3029

3130
#ifdef PADDLE_USE_OPENBLAS
3231
#include <cblas.h>
32+
#ifdef LAPACK_FOUND
3333
#include <lapacke.h>
3434
#endif
35+
#endif
3536

3637
#ifndef LAPACK_FOUND
3738
extern "C" {
@@ -126,5 +127,3 @@ template <class T>
126127
void vTanh(const int n, const T* a, T* r);
127128

128129
} // namespace paddle
129-
130-
#endif // MATHFUNCTIONS_H_

0 commit comments

Comments
 (0)