@@ -187,6 +187,9 @@ namespace dd
187187 }
188188 ++hit;
189189 }
190+
191+ // wait for predict to finish
192+ boost::unique_lock<boost::shared_mutex> lock2 (_train_or_predict_mutex);
190193 }
191194
192195 /* *
@@ -323,7 +326,7 @@ namespace dd
323326 // XXX: due to lock below, queued jobs may not
324327 // start in requested order
325328 boost::unique_lock<boost::shared_mutex> lock (
326- _train_mutex );
329+ _train_or_predict_mutex );
327330 APIData out;
328331 int run_code = this ->train (ad, out);
329332 std::pair<int , APIData> p (local_tcounter,
@@ -336,7 +339,8 @@ namespace dd
336339 }
337340 else
338341 {
339- boost::unique_lock<boost::shared_mutex> lock (_train_mutex);
342+ boost::unique_lock<boost::shared_mutex> lock (
343+ _train_or_predict_mutex);
340344 this ->_has_predict = false ;
341345 int status = this ->train (ad, out);
342346 APIData ad_params_out = ad.getobj (" parameters" ).getobj (" output" );
@@ -498,7 +502,7 @@ namespace dd
498502 oatpp::Object<DTO::PredictBody> predict_job (const APIData &ad,
499503 const bool &chain = false )
500504 {
501- if (!_train_mutex .try_lock_shared ())
505+ if (!_train_or_predict_mutex .try_lock_shared ())
502506 throw MLServiceLockException (
503507 " Predict call while training with an offline learning algorithm" );
504508
@@ -513,13 +517,13 @@ namespace dd
513517 }
514518 catch (std::exception &e)
515519 {
516- _train_mutex .unlock_shared ();
520+ _train_or_predict_mutex .unlock_shared ();
517521 this ->_stats .predict_end (false );
518522 throw ;
519523 }
520524 this ->_stats .predict_end (true );
521525
522- _train_mutex .unlock_shared ();
526+ _train_or_predict_mutex .unlock_shared ();
523527 return out;
524528 }
525529
@@ -533,7 +537,7 @@ namespace dd
533537 _training_jobs; // XXX: the futures' dtor blocks if the object is being
534538 // terminated
535539 std::unordered_map<int , APIData> _training_out;
536- boost::shared_mutex _train_mutex ;
540+ boost::shared_mutex _train_or_predict_mutex ;
537541 };
538542
539543}
0 commit comments