Skip to content

Commit 8133736

Browse files
authored
Merge pull request #11762 from tensor-tang/refine/set_num_threads
move SetNumThreads to platform
2 parents 4ed0b62 + 2e418a5 commit 8133736

File tree

8 files changed

+100
-19
lines changed

8 files changed

+100
-19
lines changed

paddle/fluid/inference/io.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ limitations under the License. */
2020
#include "paddle/fluid/framework/block_desc.h"
2121
#include "paddle/fluid/framework/feed_fetch_type.h"
2222
#include "paddle/fluid/framework/op_registry.h"
23-
#include "paddle/fluid/operators/math/blas.h"
23+
#include "paddle/fluid/platform/cpu_helper.h"
2424
#include "paddle/fluid/pybind/pybind.h"
2525

2626
DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
@@ -33,7 +33,7 @@ namespace inference {
3333

3434
void Init(const std::vector<std::string> argv) {
3535
framework::InitGflags(argv);
36-
operators::math::SetNumThreads(FLAGS_math_num_threads);
36+
platform::SetNumThreads(FLAGS_math_num_threads);
3737
// init devices
3838
std::vector<int> devices;
3939
std::string token;

paddle/fluid/inference/tests/book/test_inference_nlp.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ limitations under the License. */
1919
#include "gflags/gflags.h"
2020
#include "gtest/gtest.h"
2121
#include "paddle/fluid/inference/tests/test_helper.h"
22-
#include "paddle/fluid/operators/math/blas.h"
22+
#include "paddle/fluid/platform/cpu_helper.h"
2323
#ifdef PADDLE_WITH_MKLML
2424
#include <omp.h>
2525
#endif
@@ -164,7 +164,7 @@ TEST(inference, nlp) {
164164
// only use 1 thread number per std::thread
165165
omp_set_dynamic(0);
166166
omp_set_num_threads(1);
167-
paddle::operators::math::SetNumThreads(1);
167+
paddle::platform::SetNumThreads(1);
168168
#endif
169169

170170
double start_ms = 0, stop_ms = 0;

paddle/fluid/operators/math/blas.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,6 @@ namespace paddle {
2929
namespace operators {
3030
namespace math {
3131

32-
static void SetNumThreads(int num_threads) {
33-
#ifdef PADDLE_USE_OPENBLAS
34-
int real_num_threads = num_threads > 1 ? num_threads : 1;
35-
openblas_set_num_threads(real_num_threads);
36-
#elif defined(PADDLE_WITH_MKLML)
37-
int real_num_threads = num_threads > 1 ? num_threads : 1;
38-
platform::dynload::MKL_Set_Num_Threads(real_num_threads);
39-
#else
40-
PADDLE_ENFORCE(false, "To be implemented.");
41-
#endif
42-
}
43-
4432
/**
4533
* Matrix Descriptor of a memory buffer.
4634
*

paddle/fluid/platform/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
2828

2929
add_subdirectory(dynload)
3030

31+
cc_library(cpu_helper SRCS cpu_helper.cc DEPS cblas enforce)
32+
cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper)
33+
3134
IF(WITH_GPU)
3235
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
3336
ELSE()
@@ -43,7 +46,7 @@ ENDIF()
4346
# memcpy depends on device_context, here add deps individually for
4447
# avoiding cycle dependencies
4548
cc_library(device_context SRCS device_context.cc init.cc DEPS malloc
46-
place eigen3 stringpiece ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
49+
place eigen3 stringpiece cpu_helper ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
4750
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
4851

4952
cc_test(init_test SRCS init_test.cc DEPS device_context)

paddle/fluid/platform/cpu_helper.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/platform/cpu_helper.h"
16+
#include "paddle/fluid/platform/enforce.h"
17+
18+
#ifdef PADDLE_WITH_MKLML
19+
#include "paddle/fluid/platform/dynload/mklml.h"
20+
#endif
21+
22+
#ifdef PADDLE_USE_OPENBLAS
23+
#include <cblas.h>
24+
#endif
25+
26+
namespace paddle {
27+
namespace platform {
28+
29+
void SetNumThreads(int num_threads) {
30+
#ifdef PADDLE_USE_OPENBLAS
31+
int real_num_threads = num_threads > 1 ? num_threads : 1;
32+
openblas_set_num_threads(real_num_threads);
33+
#elif defined(PADDLE_WITH_MKLML)
34+
int real_num_threads = num_threads > 1 ? num_threads : 1;
35+
platform::dynload::MKL_Set_Num_Threads(real_num_threads);
36+
#else
37+
PADDLE_ENFORCE(false, "To be implemented.");
38+
#endif
39+
}
40+
41+
} // namespace platform
42+
} // namespace paddle

paddle/fluid/platform/cpu_helper.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <stddef.h>
18+
19+
namespace paddle {
20+
namespace platform {
21+
22+
//! Set the number of threads in use.
23+
void SetNumThreads(int num_threads);
24+
25+
} // namespace platform
26+
} // namespace paddle
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/platform/cpu_helper.h"
16+
17+
#include "gtest/gtest.h"
18+
19+
TEST(CpuHelper, SetNumThread) {
20+
paddle::platform::SetNumThreads(1);
21+
paddle::platform::SetNumThreads(4);
22+
}

paddle/fluid/platform/init.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License. */
1717
#include <string>
1818

1919
#include "paddle/fluid/framework/operator.h"
20-
#include "paddle/fluid/operators/math/blas.h"
20+
#include "paddle/fluid/platform/cpu_helper.h"
2121
#include "paddle/fluid/platform/device_context.h"
2222
#include "paddle/fluid/platform/init.h"
2323
#include "paddle/fluid/platform/place.h"
@@ -115,7 +115,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
115115
places.emplace_back(platform::CPUPlace());
116116
platform::DeviceContextPool::Init(places);
117117
#ifndef PADDLE_WITH_MKLDNN
118-
operators::math::SetNumThreads(1);
118+
platform::SetNumThreads(1);
119119
#endif
120120
}
121121

0 commit comments

Comments
 (0)