Skip to content

Commit e3a9630

Browse files
committed
move SetNumThreads to platform
1 parent b756063 commit e3a9630

File tree

9 files changed

+101
-19
lines changed

9 files changed

+101
-19
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
101101
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
102102
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
103103

104-
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operator)
104+
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece
105+
operator cpu_helper)
105106
cc_test(init_test SRCS init_test.cc DEPS init)
106107

107108
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)

paddle/fluid/framework/init.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License. */
1818

1919
#include "paddle/fluid/framework/init.h"
2020
#include "paddle/fluid/framework/operator.h"
21-
#include "paddle/fluid/operators/math/blas.h"
21+
#include "paddle/fluid/platform/cpu_helper.h"
2222
#include "paddle/fluid/platform/device_context.h"
2323
#include "paddle/fluid/platform/place.h"
2424
#include "paddle/fluid/string/piece.h"
@@ -115,7 +115,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
115115
places.emplace_back(platform::CPUPlace());
116116
platform::DeviceContextPool::Init(places);
117117
#ifndef PADDLE_WITH_MKLDNN
118-
operators::math::SetNumThreads(1);
118+
platform::SetNumThreads(1);
119119
#endif
120120
}
121121

paddle/fluid/inference/io.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ limitations under the License. */
2020
#include "paddle/fluid/framework/block_desc.h"
2121
#include "paddle/fluid/framework/feed_fetch_type.h"
2222
#include "paddle/fluid/framework/op_registry.h"
23-
#include "paddle/fluid/operators/math/blas.h"
23+
#include "paddle/fluid/platform/cpu_helper.h"
2424
#include "paddle/fluid/pybind/pybind.h"
2525

2626
DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
@@ -33,7 +33,7 @@ namespace inference {
3333

3434
void Init(const std::vector<std::string> argv) {
3535
framework::InitGflags(argv);
36-
operators::math::SetNumThreads(FLAGS_math_num_threads);
36+
platform::SetNumThreads(FLAGS_math_num_threads);
3737
// init devices
3838
std::vector<int> devices;
3939
std::string token;

paddle/fluid/inference/tests/book/test_inference_nlp.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ limitations under the License. */
1919
#include "gflags/gflags.h"
2020
#include "gtest/gtest.h"
2121
#include "paddle/fluid/inference/tests/test_helper.h"
22-
#include "paddle/fluid/operators/math/blas.h"
22+
#include "paddle/fluid/platform/cpu_helper.h"
2323
#ifdef PADDLE_WITH_MKLML
2424
#include <omp.h>
2525
#endif
@@ -164,7 +164,7 @@ TEST(inference, nlp) {
164164
// only use 1 thread number per std::thread
165165
omp_set_dynamic(0);
166166
omp_set_num_threads(1);
167-
paddle::operators::math::SetNumThreads(1);
167+
paddle::platform::SetNumThreads(1);
168168
#endif
169169

170170
double start_ms = 0, stop_ms = 0;

paddle/fluid/operators/math/blas.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,6 @@ namespace paddle {
4646
namespace operators {
4747
namespace math {
4848

49-
static void SetNumThreads(int num_threads) {
50-
#ifdef PADDLE_USE_OPENBLAS
51-
int real_num_threads = num_threads > 1 ? num_threads : 1;
52-
openblas_set_num_threads(real_num_threads);
53-
#elif defined(PADDLE_WITH_MKLML)
54-
int real_num_threads = num_threads > 1 ? num_threads : 1;
55-
platform::dynload::MKL_Set_Num_Threads(real_num_threads);
56-
#else
57-
PADDLE_ENFORCE(false, "To be implemented.");
58-
#endif
59-
}
60-
6149
/**
6250
* Matrix Descriptor of a memory buffer.
6351
*

paddle/fluid/platform/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
2828

2929
add_subdirectory(dynload)
3030

31+
cc_library(cpu_helper SRCS cpu_helper.cc DEPS cblas enforce)
32+
cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper)
33+
3134
IF(WITH_GPU)
3235
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
3336
ELSE()

paddle/fluid/platform/cpu_helper.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/platform/cpu_helper.h"
16+
#include "paddle/fluid/platform/enforce.h"
17+
18+
#ifdef PADDLE_WITH_MKLML
19+
#include "paddle/fluid/platform/dynload/mklml.h"
20+
#endif
21+
22+
#ifdef PADDLE_USE_OPENBLAS
23+
#include <cblas.h>
24+
#endif
25+
26+
namespace paddle {
27+
namespace platform {
28+
29+
void SetNumThreads(int num_threads) {
30+
#ifdef PADDLE_USE_OPENBLAS
31+
int real_num_threads = num_threads > 1 ? num_threads : 1;
32+
openblas_set_num_threads(real_num_threads);
33+
#elif defined(PADDLE_WITH_MKLML)
34+
int real_num_threads = num_threads > 1 ? num_threads : 1;
35+
platform::dynload::MKL_Set_Num_Threads(real_num_threads);
36+
#else
37+
PADDLE_ENFORCE(false, "To be implemented.");
38+
#endif
39+
}
40+
41+
} // namespace platform
42+
} // namespace paddle

paddle/fluid/platform/cpu_helper.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <stddef.h>
18+
19+
namespace paddle {
20+
namespace platform {
21+
22+
//! Set the number of threads in use.
23+
void SetNumThreads(int num_threads);
24+
25+
} // namespace platform
26+
} // namespace paddle
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/platform/cpu_helper.h"
16+
17+
#include "gtest/gtest.h"
18+
19+
TEST(CpuHelper, SetNumThread) {
20+
paddle::platform::SetNumThreads(1);
21+
paddle::platform::SetNumThreads(4);
22+
}

0 commit comments

Comments
 (0)