Skip to content

Commit 0ef1f46

Browse files
Bycobmergify[bot]
authored andcommitted
fix: prevent crash when a service is deleted before finishing predict
1 parent 2b07002 commit 0ef1f46

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

src/mlservice.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ namespace dd
187187
}
188188
++hit;
189189
}
190+
191+
// wait for predict to finish
192+
boost::unique_lock<boost::shared_mutex> lock2(_train_or_predict_mutex);
190193
}
191194

192195
/**
@@ -323,7 +326,7 @@ namespace dd
323326
// XXX: due to lock below, queued jobs may not
324327
// start in requested order
325328
boost::unique_lock<boost::shared_mutex> lock(
326-
_train_mutex);
329+
_train_or_predict_mutex);
327330
APIData out;
328331
int run_code = this->train(ad, out);
329332
std::pair<int, APIData> p(local_tcounter,
@@ -336,7 +339,8 @@ namespace dd
336339
}
337340
else
338341
{
339-
boost::unique_lock<boost::shared_mutex> lock(_train_mutex);
342+
boost::unique_lock<boost::shared_mutex> lock(
343+
_train_or_predict_mutex);
340344
this->_has_predict = false;
341345
int status = this->train(ad, out);
342346
APIData ad_params_out = ad.getobj("parameters").getobj("output");
@@ -498,7 +502,7 @@ namespace dd
498502
oatpp::Object<DTO::PredictBody> predict_job(const APIData &ad,
499503
const bool &chain = false)
500504
{
501-
if (!_train_mutex.try_lock_shared())
505+
if (!_train_or_predict_mutex.try_lock_shared())
502506
throw MLServiceLockException(
503507
"Predict call while training with an offline learning algorithm");
504508

@@ -513,13 +517,13 @@ namespace dd
513517
}
514518
catch (std::exception &e)
515519
{
516-
_train_mutex.unlock_shared();
520+
_train_or_predict_mutex.unlock_shared();
517521
this->_stats.predict_end(false);
518522
throw;
519523
}
520524
this->_stats.predict_end(true);
521525

522-
_train_mutex.unlock_shared();
526+
_train_or_predict_mutex.unlock_shared();
523527
return out;
524528
}
525529

@@ -533,7 +537,7 @@ namespace dd
533537
_training_jobs; // XXX: the futures' dtor blocks if the object is being
534538
// terminated
535539
std::unordered_map<int, APIData> _training_out;
536-
boost::shared_mutex _train_mutex;
540+
boost::shared_mutex _train_or_predict_mutex;
537541
};
538542

539543
}

0 commit comments

Comments
 (0)