Skip to content

Commit a2dbc19

Browse files
committed
The pserver now (optionally) prunes the provDB
To enable the above, the pserver now requires the full algo_params JSON file to be passed in vs just the algorithm name. Reordered the arguments to PSshardProvenanceDBclient::connectShard to put shard index first Fixed bug in AlgoParams JSON parse where the checking of present arguments was being performed incorrectly for algorithms other than the HBOS default
1 parent 0b63ba6 commit a2dbc19

File tree

10 files changed

+61
-22
lines changed

10 files changed

+61
-22
lines changed

app/pserver.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010

1111
#ifdef ENABLE_PROVDB
1212
#include <chimbuko/core/pserver/PSglobalProvenanceDBclient.hpp>
13+
#include <chimbuko/core/pserver/PSshardProvenanceDBclient.hpp>
1314
#endif
1415

1516
#include <chimbuko/core/param/sstd_param.hpp>
1617
#include <chimbuko/core/util/commandLineParser.hpp>
1718
#include <chimbuko/core/util/error.hpp>
1819
#include <fstream>
19-
#include "chimbuko/core/verbose.hpp"
20+
#include <chimbuko/core/verbose.hpp>
21+
#include <chimbuko/core/ad/ADOutlier.hpp>
2022

2123
#include <chimbuko/modules/factory.hpp>
2224

@@ -40,11 +42,12 @@ struct pserverArgs{
4042
int stat_send_freq;
4143

4244
std::string stat_outputdir;
43-
std::string ad;
4445

4546
int model_update_freq; //frequency in ms at which the global model is updated
4647
bool model_force_update; //force the global model to be updated every time a worker thread updates its model
4748

49+
ADOutlier::AlgoParams algo_params; //The AD algorithm hyperparameters and type/name
50+
4851
#ifdef _USE_ZMQNET
4952
int max_pollcyc_msg;
5053
int zmq_io_thr;
@@ -53,25 +56,27 @@ struct pserverArgs{
5356

5457
#ifdef ENABLE_PROVDB
5558
std::string provdb_addr_dir;
59+
int nprovdb_shards; /**< Number of database shards*/
60+
int nprovdb_instances; /**< Number of instances of the provenance database server*/
5661
std::string provdb_mercury_auth_key; //An authorization key for initializing Mercury (optional, default "")
62+
bool provdb_post_prune; //perform post-pruning on the provenance database
5763
#endif
5864

5965
std::string prov_outputpath;
6066

61-
pserverArgs(): ad("hbos"), nt(-1), logdir("."), ws_addr(""), load_params_set(false), save_params_set(false), freeze_params(false), stat_send_freq(1000), stat_outputdir(""), port(5559), prov_outputpath(""), model_update_freq(1000), model_force_update(false)
67+
pserverArgs(): nt(-1), logdir("."), ws_addr(""), load_params_set(false), save_params_set(false), freeze_params(false), stat_send_freq(1000), stat_outputdir(""), port(5559), prov_outputpath(""), model_update_freq(1000), model_force_update(false)
6268
#ifdef _USE_ZMQNET
6369
, max_pollcyc_msg(10), zmq_io_thr(1), autoshutdown(true)
6470
#endif
6571
#ifdef ENABLE_PROVDB
66-
, provdb_addr_dir(""), provdb_mercury_auth_key("")
72+
, provdb_addr_dir(""), provdb_mercury_auth_key(""), provdb_post_prune(true), nprovdb_shards(1), nprovdb_instances(1)
6773
#endif
6874
{}
6975

7076
static commandLineParser &getParser(pserverArgs &instance){
7177
static bool init = false;
7278
static commandLineParser p;
7379
if(!init){
74-
addOptionalCommandLineArg(p, instance, ad, "Set AD algorithm to use.");
7580
addOptionalCommandLineArg(p, instance, nt, "Set the number of RPC handler threads (max-2 by default)");
7681
addOptionalCommandLineArg(p, instance, logdir, "Set the output log directory (default: job directory)");
7782
addOptionalCommandLineArg(p, instance, port, "Set the pserver port (default: 5559)");
@@ -89,11 +94,15 @@ struct pserverArgs{
8994
#ifdef ENABLE_PROVDB
9095
addOptionalCommandLineArg(p, instance, provdb_addr_dir, "The directory containing the address file written out by the provDB server. An empty string will disable the connection to the global DB. (default empty, disabled)");
9196
addOptionalCommandLineArg(p, instance, provdb_mercury_auth_key, "Set the Mercury authorization key for connection to the provDB (default \"\")");
97+
addOptionalCommandLineArg(p, instance, provdb_post_prune, "If enabled the pserver will automatically \"prune\" the provenance database at the end of the run (default: true)");
98+
addOptionalCommandLineArgWithDefault(p, instance, nprovdb_shards, 1, "Number of provenance database shards. Clients connect to shards round-robin by rank (default 1)");
99+
addOptionalCommandLineArgWithDefault(p, instance, nprovdb_instances, 1, "Number of provenance database instances. Shards are divided uniformly over instances. (default 1)");
92100
#endif
93101
addOptionalCommandLineArg(p, instance, prov_outputpath, "Output global provenance data to this directory. Can be used in place of or in conjunction with the provenance database. An empty string \"\" (default) disables this output");
94102
addOptionalCommandLineArg(p, instance, model_update_freq, "The frequency in ms at which the global AD model is updated (default 1000ms)");
95103
addOptionalCommandLineArg(p, instance, model_force_update, "Force the global AD model to be updated every time a worker thread updates its model (default false)");
96104

105+
p.addOptionalArg(new ADOutlier::AlgoParams::cmdlineParser(instance.algo_params, "-algo_params_file", "Set the filename containing the algorithm name and hyperparameters (ensure consistent with OAD)."));
97106
init = true;
98107
}
99108
return p;
@@ -132,7 +141,7 @@ int main (int argc, char ** argv){
132141
enableVerboseLogging() = true;
133142
}
134143

135-
PSparamManager param(args.nt, args.ad); //the AD model; independent models for each worker thread that are aggregated periodically to a global model
144+
PSparamManager param(args.nt, args.algo_params.algorithm); //the AD model; independent models for each worker thread that are aggregated periodically to a global model
136145
param.enableForceUpdate(args.model_force_update); //decide whether the model is forced to be updated every time a worker updates its model
137146

138147
std::unique_ptr<ProvDBmoduleSetupCore> pdb_setup = modules::factoryInstantiateProvDBmoduleSetup(module);
@@ -272,6 +281,22 @@ int main (int argc, char ** argv){
272281
ps_module_data_man->writeModel(args.save_params, param);
273282
}
274283

284+
285+
//Post-prune the provenance database
286+
if(args.provdb_post_prune && args.provdb_addr_dir.size()){
287+
progressStream << "PServer: Pruning the provenance database" << std::endl;
288+
std::unique_ptr<ProvDBpruneCore> pruner = modules::factoryInstantiateProvDBprune(module, args.algo_params, param.getGlobalParamsCopy()->serialize());
289+
290+
for(int s=0;s<args.nprovdb_shards;s++){
291+
progressStream << "PServer: Pruning shard " << s+1 << " of " << args.nprovdb_shards << std::endl;
292+
PSshardProvenanceDBclient shard_client(pdb_setup->getMainDBcollections());
293+
shard_client.connectShard(s, args.provdb_addr_dir, args.nprovdb_shards, args.nprovdb_instances);
294+
pruner->prune(shard_client.getDatabase());
295+
}
296+
progressStream << "PServer: Updating global function stats" << std::endl;
297+
pruner->finalize(provdb_client.getDatabase());
298+
}
299+
275300
progressStream << "Pserver: finished" << std::endl;
276301

277302
return 0;

include/chimbuko/core/provdb/setup.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ namespace chimbuko{
7474
*/
7575
inline int getNshards() const{ return m_nshards; }
7676

77+
/**
78+
* @brief Return the total number of instances
79+
*/
80+
inline int getNinstance() const{ return m_ninstances; }
81+
7782
/**
7883
* @brief Get the number of shards serviced by a given instance
7984
*/

include/chimbuko/core/pserver/PSshardProvenanceDBclient.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace chimbuko{
2121
* @brief Connect the client the appropriate provenance database server instance / shard using the default setup
2222
* @param addr_file_dir The directory containing the address files created by the provDB
2323
*/
24-
void connectShard(const std::string &addr_file_dir, int shard, int nshards, int ninstances);
24+
void connectShard(int shard, const std::string &addr_file_dir, const int nshards, const int ninstances);
2525

2626
/**
2727
* @brief No handshake is needed

scripts/launch/run_services.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ if (( ${use_provdb} == 1 )); then
204204
done
205205

206206
extra_args+=" -provdb_addr_dir ${provdb_addr_dir} -nprovdb_instances ${provdb_ninstances} -nprovdb_shards ${provdb_nshards}"
207-
ps_extra_args+=" -provdb_addr_dir ${provdb_addr_dir}"
207+
ps_extra_args+=" -provdb_addr_dir ${provdb_addr_dir} -nprovdb_instances ${provdb_ninstances} -nprovdb_shards ${provdb_nshards}"
208208
echo "Chimbuko Services: Enabling provenance database with arg: ${extra_args}"
209209
cd -
210210

@@ -370,7 +370,7 @@ if (( ${use_pserver} == 1 )); then
370370

371371
pserver_alg=${ad_alg} #Pserver AD algorithm choice must match that used for the driver
372372
pserver_addr="tcp://${ip}:${pserver_port}" #address for parameter server in format "tcp://IP:PORT"
373-
cmd="pserver ${module} -ad ${pserver_alg} -nt ${pserver_nt} -logdir ${log_dir} -port ${pserver_port} -save_params ${ps_dir}/global_model.json ${ps_extra_args} 2>&1 | tee ${log_dir}/pserver.log &"
373+
cmd="pserver ${module} -algo_params_file ${algo_params_file} -nt ${pserver_nt} -logdir ${log_dir} -port ${pserver_port} -save_params ${ps_dir}/global_model.json ${ps_extra_args} 2>&1 | tee ${log_dir}/pserver.log &"
374374
if [[ ! -z "${pserver_numa_bind:-}" ]]; then
375375
echo "Chimbuko Services binding pserver to NUMA domain ${pserver_numa_bind}"
376376
cmd="numactl -N ${pserver_numa_bind} ${cmd}"

src/core/ad/ADOutlier.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ void ADOutlier::AlgoParams::setJson(const nlohmann::json &in){
2323
#define JSON_GET(to) if(in.contains(#to)) to = in[#to].template get<decltype(to)>()
2424
//Check for required
2525
JSON_CHECK(algorithm);
26+
JSON_GET(algorithm);
27+
2628
if(algorithm == "sstd"){
2729
JSON_CHECK(sstd_sigma);
2830
}else if(algorithm == "hbos"){
@@ -33,7 +35,6 @@ void ADOutlier::AlgoParams::setJson(const nlohmann::json &in){
3335
JSON_CHECK(hbos_thres);
3436
}
3537
//Get all available
36-
JSON_GET(algorithm);
3738
JSON_GET(sstd_sigma);
3839
JSON_GET(glob_thres);
3940
JSON_GET(hbos_max_bins);
@@ -62,9 +63,7 @@ nlohmann::json ADOutlier::AlgoParams::getJson() const{
6263
}
6364

6465
int ADOutlier::AlgoParams::cmdlineParser::parse(const std::string &arg, const char** vals, const int vals_size){
65-
std::cout << "TEST ARG " << arg << "==" << m_arg << std::endl;
6666
if(arg == m_arg){
67-
std::cout << "TEST ARG FOUND ARG " << arg << std::endl;
6867
if(vals_size < 1) return -1;
6968

7069
try{

src/core/pserver/PSshardProvenanceDBclient.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using namespace chimbuko;
88

99
PSshardProvenanceDBclient::~PSshardProvenanceDBclient(){ disconnect(); } //call disconnect in derived class to ensure derived class is still alive when disconnect is called
1010

11-
void PSshardProvenanceDBclient::connectShard(const std::string &addr_file_dir, int shard, int nshards, int ninstances){
11+
void PSshardProvenanceDBclient::connectShard(int shard, const std::string &addr_file_dir, const int nshards, const int ninstances){
1212
ProvDBsetup setup(nshards, ninstances);
1313

1414
int instance = setup.getShardInstance(shard);

src/modules/performance_analysis/provdb/ProvDBprune.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include<chimbuko/modules/performance_analysis/provdb/ProvDBprune.hpp>
22
#include<chimbuko/modules/performance_analysis/provdb/ProvDBpruneInterface.hpp>
33
#include<chimbuko/core/util/error.hpp>
4+
#include<chimbuko/core/verbose.hpp>
45
#include<limits>
56

67
using namespace chimbuko;
@@ -10,11 +11,13 @@ void ProvDBprune::pruneImplementation(ADOutlier &ad, sonata::Database &db){
1011
//Prune the outliers and update scores / model on remaining. Also gather new anomaly statistics
1112
{
1213
ProvDBpruneInterface po(ad, db, ADDataInterface::EventType::Outlier, &m_anom_metrics);
14+
progressStream << "Pruning normal events from 'anomalies' database with " << po.nDataSets() << " function indices" << std::endl;
1315
ad.run(po);
1416
}
1517
//Prune normal execs and update scores / model on remaining
1618
{
1719
ProvDBpruneInterface pn(ad, db, ADDataInterface::EventType::Normal);
20+
progressStream << "Pruning outlier events from 'normal_execs' database with " << pn.nDataSets() << " function indices" << std::endl;
1821
ad.run(pn);
1922
}
2023
}

src/modules/performance_analysis/provdb/ProvDBpruneInterface.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@ while( ($rec = db_fetch(')" + coll + R"(')) != NULL ){
2626

2727
nlohmann::json rj = nlohmann::json::parse(r->second);
2828
if(!rj.is_array()) fatal_error("Expected an array type");
29-
30-
for(int i=0;i<rj.size();i++)
31-
std::cout << i << " " << rj[i][0] << " " << rj[i][1] << " " << rj[i][2] << std::endl;
32-
29+
3330
for(auto const &e: rj){
3431
uint64_t id = e[0].template get<uint64_t>();
3532
unsigned long fid = e[1].template get<unsigned long>();
@@ -54,6 +51,7 @@ void ProvDBpruneInterface::recordDataSetLabelsInternal(const std::vector<Elem> &
5451
std::vector<uint64_t> to_update;
5552
std::vector<double> update_scores;
5653

54+
int fid = this->getDataSetModelIndex(dset_index);
5755
ADDataInterface::EventType type_to_remove = m_prune_type == ADDataInterface::EventType::Outlier ? ADDataInterface::EventType::Normal : ADDataInterface::EventType::Outlier;
5856

5957
for(auto const &e: data){ //we used the record id as the index, which is unique
@@ -65,12 +63,12 @@ void ProvDBpruneInterface::recordDataSetLabelsInternal(const std::vector<Elem> &
6563
update_scores.push_back(e.score);
6664
}
6765
}
68-
progressStream << "Pruning " << to_prune.size() << " of " << data.size() << " records in shard" << std::endl;
69-
66+
progressStream << "Pruning " << to_prune.size() << " of " << data.size() << " records in shard for function index " << fid << std::endl;
67+
7068
m_collection->erase_multi(to_prune.data(), to_prune.size(), true); //last entry tells database to commit change
7169

7270
//Grab records to update in batches
73-
progressStream << "Updating scores and models for " << to_update.size() << " records" << std::endl;
71+
progressStream << "Updating scores and models for " << to_update.size() << " records with function index " << fid << std::endl;
7472
std::unordered_map<uint64_t, nlohmann::json> json_param_cache;
7573
batchAmendRecords(*m_collection, to_update, [&](nlohmann::json &rec, size_t i){
7674
//Update score

test/run_ad.sh

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,16 @@ mkdir -p temp perf
2121

2222
export CHIMBUKO_DISABLE_CUDA_JIT_WORKAROUND=1
2323

24-
mpirun --allow-run-as-root --oversubscribe -n 1 ${appdir}/pserver performance_analysis -logdir "./perf/" -ad "sstd" -model_force_update 1 &
24+
cat <<EOF > /tmp/algo_params.json
25+
{
26+
"algorithm" : "sstd",
27+
"sstd_sigma" : 6.0
28+
}
29+
EOF
30+
31+
cat /tmp/algo_params.json
32+
echo "Starting pserver"
33+
mpirun --allow-run-as-root --oversubscribe -n 1 ${appdir}/pserver performance_analysis -logdir "./perf/" -algo_params_file "/tmp/algo_params.json" -model_force_update 1 &
2534
ps_wid=$!
2635

2736
sleep 5

test/unit_tests/modules/performance_analysis/provdb/ProvDBprune.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ TEST(TestProvDBprune, works){
2222
std::vector<std::unique_ptr<PSshardProvenanceDBclient> > shard_clients(nshards);
2323
for(int i=0;i<nshards;i++){
2424
shard_clients[i].reset(new PSshardProvenanceDBclient(setup.getMainDBcollections()));
25-
shard_clients[i]->connectShard("/tmp",i,nshards,1);
25+
shard_clients[i]->connectShard(i,"/tmp",nshards,1);
2626
}
2727

2828
PSglobalProvenanceDBclient glob_client(setup.getGlobalDBcollections());

0 commit comments

Comments
 (0)