Skip to content

Commit 1f295e6

Browse files
authored
Fix ft single thread performance (#4441)
1 parent b2a24c6 commit 1f295e6

File tree

1 file changed

+17
-10
lines changed
  • fast_tokenizer/fast_tokenizer/core

1 file changed

+17
-10
lines changed

fast_tokenizer/fast_tokenizer/core/base.cc

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "fast_tokenizer/core/base.h"
16+
1617
#include <thread>
1718

1819
namespace paddlenlp {
@@ -28,16 +29,22 @@ int GetThreadNum() { return fast_tokenizer_thread_num; }
2829
void RunMultiThread(std::function<void(size_t, size_t)> func,
2930
size_t batch_size) {
3031
int thread_num = GetThreadNum();
31-
std::vector<std::thread> vectorOfThread;
32-
size_t start_index = 0;
33-
size_t step_index = ceil(batch_size / float(thread_num));
34-
35-
for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
36-
vectorOfThread.emplace_back(std::thread(func, start_index, step_index));
37-
start_index = start_index + step_index;
38-
}
39-
for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
40-
vectorOfThread[thread_index].join();
32+
if (thread_num == 1) {
33+
// Note(zhoushunjie): No need to create threads when
34+
// thread_num equals to 1.
35+
func(0, batch_size);
36+
} else {
37+
std::vector<std::thread> vectorOfThread;
38+
size_t start_index = 0;
39+
size_t step_index = ceil(batch_size / float(thread_num));
40+
41+
for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
42+
vectorOfThread.emplace_back(std::thread(func, start_index, step_index));
43+
start_index = start_index + step_index;
44+
}
45+
for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
46+
vectorOfThread[thread_index].join();
47+
}
4148
}
4249
}
4350

0 commit comments

Comments
 (0)