Skip to content

Commit e05f4ff

Browse files
reyoung authored and hedaoyuan committed
Fix SRL hang on exit. (#291)
* Fix SRL hang on exit. * An error occurred when Async Load was enabled in TestDataProvider. * It happened because one thread was calling getNextBatchInternal on the DataProvider while another thread was destructing that same DataProvider. * Add a wait routine when destructing DataProvider. * Also fix another bug that occurred when destructing a TestDataProvider that had not read any test data. Fix #286 * Follow review comments: using a mutex is cool!
1 parent c64cd6f commit e05f4ff

File tree

4 files changed

+46
-3
lines changed

4 files changed

+46
-3
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
*.pyc
2+
train.log
3+
data/feature
4+
data/conll05st-release/
5+
data/src.dict
6+
data/test.wsj.props
7+
data/test.wsj.seq_pair
8+
data/test.wsj.words
9+
data/tgt.dict
10+
output

paddle/gserver/dataproviders/DataProvider.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,10 @@ void DoubleBuffer::asyncLoadBatch() {
131131
taskReadySem_.wait();
132132
if (stopping_) break;
133133

134-
while (batchSize_ == 0) {
134+
while (batchSize_ == 0 && !stopping_) {
135135
usleep(5);
136136
}
137+
if (stopping_) break;
137138

138139
do {
139140
DataBatch newBatch;

paddle/gserver/dataproviders/PyDataProvider2.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,26 +433,34 @@ class PyDataProvider2 : public DataProvider {
433433

434434
inline void resetImpl(bool startNewThread) {
435435
DBG << "Reseting " << startNewThread;
436+
exit_.store(true);
436437
if (loadThread_) { // is loading.
437-
exit_.store(true);
438438
loadThread_->join();
439439
loadThread_.reset();
440440
}
441441
{
442442
PyGuard g;
443443
callingContexts_.clear();
444+
this->pullCV_.notify_one();
445+
}
446+
447+
std::lock_guard<std::mutex> guard(mutexForReset_);
448+
{
449+
PyGuard g;
444450
dataPool_.clear();
445451
}
446452
poolActualSize_ = 0;
447-
exit_ = false;
453+
448454
if (startNewThread && cache_->reset()) {
449455
DBG << "Start new thread.";
450456
loadThread_.reset(new std::thread([this] {
457+
exit_ = false;
451458
loadThread();
452459
}));
453460
callingContextCreated_.wait();
454461
}
455462
DBG << "Reset done";
463+
exit_ = false;
456464
}
457465

458466
private:
@@ -465,6 +473,8 @@ class PyDataProvider2 : public DataProvider {
465473
std::condition_variable pullCV_;
466474
std::mutex mtx_;
467475

476+
std::mutex mutexForReset_;
477+
468478
ThreadBarrier callingContextCreated_;
469479
std::unique_ptr<IPyDataProviderCache> cache_;
470480

@@ -529,6 +539,7 @@ class PyDataProvider2 : public DataProvider {
529539
* Loading a batch of data.
530540
*/
531541
int64_t getNextBatchInternal(int64_t size_, DataBatch *batch) {
542+
std::lock_guard<std::mutex> guard(mutexForReset_);
532543
REGISTER_TIMER("PyDP2.getNextBatchInternal")
533544
CHECK_GE(size_, 0);
534545
size_t size = (size_t) size_;
@@ -554,6 +565,10 @@ class PyDataProvider2 : public DataProvider {
554565
} else { // loading from cache.
555566
poolPtr = this->cache_->load();
556567
}
568+
if (exit_) {
569+
// PyDataProvider is destructing.
570+
return 0;
571+
}
557572
CHECK(poolPtr != nullptr);
558573

559574
std::deque<PyObjectPtr>& pool = *poolPtr;

paddle/gserver/tests/test_PyDataProvider2.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,23 @@ TEST(PyDataProvider2, test_check) {
353353
}
354354
}
355355

356+
TEST(PyDataProvider2, multiThread) {
357+
paddle::DataConfig config;
358+
config.set_type("py2");
359+
config.set_files(FLAGS_train_list.c_str());
360+
config.set_load_data_module("test_PyDataProvider2");
361+
config.set_load_data_object("test_dense_no_seq");
362+
config.set_async_load_data(true);
363+
364+
std::unique_ptr<paddle::DataProvider> provider(
365+
paddle::DataProvider::create(config, false));
366+
provider->reset();
367+
paddle::DataBatch batch;
368+
provider->getNextBatch(100, &batch);
369+
provider->reset();
370+
provider.reset();
371+
}
372+
356373
int main(int argc, char** argv) {
357374
testing::InitGoogleTest(&argc, argv);
358375
paddle::initMain(argc, argv);

0 commit comments

Comments
 (0)