Skip to content

Commit 8369129

Browse files
committed
Changes to PSparamManager updating of global model:
The merging of the worker models into the global model now uses a new virtual function of ParamInterface that takes a pointer array (which defaults to a simple loop). This provides support for more sophisticated merge strategies. The merge is now performed into a temporary object that is then moved to replace the global model. The workers are only locked while the pointers are updated, improving parallelization.
1 parent 9f4cecd commit 8369129

File tree

4 files changed

+29
-6
lines changed

4 files changed

+29
-6
lines changed

include/chimbuko/param.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ namespace chimbuko {
6060
*/
6161
virtual void update(const ParamInterface &other) = 0;
6262

63+
/**
64+
* @brief Update the internal run statistics with those from multiple other instances
65+
*
66+
* The instance will be dynamically cast to the derived type internally, and will throw an error if the types do not match
67+
* The other instance will be locked during the process for thread safety
68+
*/
69+
virtual void update(const std::vector<ParamInterface*> &other);
6370

6471
/**
6572
* @brief Set the internal run statistics to match those included in the serialized input map. Overwrite performed only for those keys in input.

include/chimbuko/pserver/PSparamManager.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ namespace chimbuko{
8686
ParamType & getWorkerParams(const int i){ return dynamic_cast<ParamType&>(*m_worker_params[i]); }
8787

8888
private:
89+
std::string m_ad_algorithm; /**< The AD algorithm*/
8990
int m_agg_freq_ms; /**< The frequence in ms at which the global model is updated. Default 1000ms*/
9091
ParamInterface *m_global_params; /**< The global model*/
9192
std::string m_latest_global_params; /**< Cache of the serialized form the the latest global model*/

src/param.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,9 @@ ParamInterface *ParamInterface::set_AdParam(const std::string & ad_algorithm) {
2525
fatal_error("Invalid algorithm: \"" + ad_algorithm + "\". Available options: HBOS, SSTD, COPOD");
2626
}
2727
}
28+
29+
void ParamInterface::update(const std::vector<ParamInterface*> &other){
30+
for(auto p : other){
31+
this->update(*p);
32+
}
33+
}

src/pserver/PSparamManager.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
using namespace chimbuko;
66

7-
PSparamManager::PSparamManager(const int nworker, const std::string &ad_algorithm): m_agg_freq_ms(1000), m_updater_thread(nullptr), m_worker_params(nworker,nullptr), m_global_params(nullptr), m_updater_exit(false), m_force_update(false){
7+
PSparamManager::PSparamManager(const int nworker, const std::string &ad_algorithm): m_agg_freq_ms(1000), m_updater_thread(nullptr), m_worker_params(nworker,nullptr), m_global_params(nullptr), m_updater_exit(false), m_force_update(false), m_ad_algorithm(ad_algorithm){
88
for(int i=0;i<nworker;i++)
99
m_worker_params[i] = ParamInterface::set_AdParam(ad_algorithm);
1010
m_global_params = ParamInterface::set_AdParam(ad_algorithm);
@@ -13,11 +13,20 @@ PSparamManager::PSparamManager(const int nworker, const std::string &ad_algorith
1313

1414
void PSparamManager::updateGlobalModel(){
1515
verboseStream << "PSparamManager::updateGlobalModel updating global model" << std::endl;
16-
std::unique_lock<std::shared_mutex> _(m_mutex); //unique lock to prevent read/write from other threads
17-
m_global_params->clear(); //reset the global params and reform from worker params which have been aggregating since the start of the run
18-
for(auto p: m_worker_params)
19-
m_global_params->update(*p); //locks the worker params temporarily
20-
m_latest_global_params = m_global_params->serialize();
16+
17+
//Avoid needing to lock out worker threads while updating by merging into a new location and moving after
18+
ParamInterface* new_glob_params = ParamInterface::set_AdParam(m_ad_algorithm);
19+
new_glob_params->update(m_worker_params);
20+
std::string new_glob_params_ser = new_glob_params->serialize();
21+
22+
ParamInterface *tmp;
23+
{
24+
std::unique_lock<std::shared_mutex> _(m_mutex); //unique lock to prevent read/write from other threads
25+
tmp = m_global_params;
26+
m_global_params = new_glob_params;
27+
m_latest_global_params = std::move(new_glob_params_ser);
28+
}
29+
delete tmp;
2130
}
2231

2332

0 commit comments

Comments
 (0)