@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "fast_tokenizer/core/base.h"
+
 #include <thread>
 
 namespace paddlenlp {
@@ -28,16 +29,22 @@ int GetThreadNum() { return fast_tokenizer_thread_num; }
 void RunMultiThread(std::function<void(size_t, size_t)> func,
                     size_t batch_size) {
   int thread_num = GetThreadNum();
-  std::vector<std::thread> vectorOfThread;
-  size_t start_index = 0;
-  size_t step_index = ceil(batch_size / float(thread_num));
-
-  for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
-    vectorOfThread.emplace_back(std::thread(func, start_index, step_index));
-    start_index = start_index + step_index;
-  }
-  for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
-    vectorOfThread[thread_index].join();
+  if (thread_num == 1) {
+    // Note(zhoushunjie): No need to create threads when
+    // thread_num equals to 1.
+    func(0, batch_size);
+  } else {
+    std::vector<std::thread> vectorOfThread;
+    size_t start_index = 0;
+    size_t step_index = ceil(batch_size / float(thread_num));
+
+    for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
+      vectorOfThread.emplace_back(std::thread(func, start_index, step_index));
+      start_index = start_index + step_index;
+    }
+    for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
+      vectorOfThread[thread_index].join();
+    }
   }
 }
 
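For readers following the threading change, the sketch below is a self-contained illustration of the partitioning scheme RunMultiThread relies on, including the new thread_num == 1 shortcut that runs the work inline instead of spawning a single thread. It is not the library code: RunChunks, the explicit clamp on the final chunk (which keeps the last worker from stepping past batch_size when the batch does not divide evenly), and the small main() driver are assumptions added purely for illustration.

```cpp
// Standalone sketch (not part of this patch) of the same chunking scheme
// RunMultiThread uses: the batch is split into chunks of
// ceil(batch_size / thread_num) elements and each worker is invoked with
// (start_index, chunk_size). RunChunks, the clamp on the final chunk, and
// main() are illustrative assumptions, not fast_tokenizer code.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iostream>
#include <thread>
#include <vector>

void RunChunks(std::function<void(size_t, size_t)> func,
               size_t batch_size,
               int thread_num) {
  if (thread_num == 1) {
    // Same shortcut the patch adds: run inline, no threads.
    func(0, batch_size);
    return;
  }
  size_t step = static_cast<size_t>(
      std::ceil(batch_size / static_cast<float>(thread_num)));
  std::vector<std::thread> workers;
  for (int i = 0; i < thread_num; ++i) {
    size_t start = static_cast<size_t>(i) * step;
    // Clamp so the last chunk never runs past batch_size.
    size_t count = start < batch_size ? std::min(step, batch_size - start) : 0;
    workers.emplace_back(func, start, count);
  }
  for (auto& t : workers) {
    t.join();
  }
}

int main() {
  std::vector<int> data(10, 1);
  RunChunks(
      [&](size_t start, size_t count) {
        for (size_t i = start; i < start + count; ++i) {
          data[i] += 1;
        }
      },
      data.size(),
      4);
  for (int v : data) {
    std::cout << v << ' ';  // Expected output: ten 2s.
  }
  std::cout << '\n';
  return 0;
}
```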