Skip to content

Commit 1998f6b

Browse files
author
Han Wang
committed
set nthreads in md
1 parent 5caff7d commit 1998f6b

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

source/lib/include/NNPInter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class NNPInter
9393
VALUETYPE cutoff () const {return rcut;};
9494
private:
9595
Session* session;
96+
int num_intra_nthreads, num_inter_nthreads;
9697
GraphDef graph_def;
9798
bool inited;
9899
VALUETYPE get_rcut () const;
@@ -152,6 +153,7 @@ class NNPInterModelDevi
152153
private:
153154
unsigned numb_models;
154155
vector<Session*> sessions;
156+
int num_intra_nthreads, num_inter_nthreads;
155157
vector<GraphDef> graph_defs;
156158
bool inited;
157159
VALUETYPE get_rcut () const;

source/lib/src/NNPInter.cc

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -408,17 +408,44 @@ run_model (ENERGYTYPE & dener,
408408
nnpmap.backward (datom_virial_.begin(), datom_virial.begin(), 9);
409409
}
410410

411+
static void
412+
get_env_nthreads(int & num_intra_nthreads,
413+
int & num_inter_nthreads)
414+
{
415+
num_intra_nthreads = 0;
416+
num_inter_nthreads = 0;
417+
const char* env_intra_nthreads = std::getenv("OMP_NUM_THREADS");
418+
const char* env_inter_nthreads = std::getenv("TF_INTER_OP_PARALLELISM_THREADS");
419+
if (env_intra_nthreads &&
420+
string(env_intra_nthreads) != string("") &&
421+
atoi(env_intra_nthreads) >= 0
422+
) {
423+
num_intra_nthreads = atoi(env_intra_nthreads);
424+
}
425+
if (env_inter_nthreads &&
426+
string(env_inter_nthreads) != string("") &&
427+
atoi(env_inter_nthreads) >= 0
428+
) {
429+
num_inter_nthreads = atoi(env_inter_nthreads);
430+
}
431+
}
432+
411433

412434
NNPInter::
413435
NNPInter ()
414436
: inited (false)
415437
{
438+
get_env_nthreads(num_intra_nthreads, num_inter_nthreads);
416439
}
417440

418441
NNPInter::
419442
NNPInter (const string & model)
420443
{
421-
checkStatus (NewSession(SessionOptions(), &session));
444+
get_env_nthreads(num_intra_nthreads, num_inter_nthreads);
445+
SessionOptions options;
446+
options.config.set_inter_op_parallelism_threads(num_inter_nthreads);
447+
options.config.set_intra_op_parallelism_threads(num_intra_nthreads);
448+
checkStatus (NewSession(options, &session));
422449
checkStatus (ReadBinaryProto(Env::Default(), model, &graph_def));
423450
checkStatus (session->Create(graph_def));
424451
rcut = get_rcut();
@@ -432,7 +459,10 @@ NNPInter::
432459
init (const string & model)
433460
{
434461
assert (!inited);
435-
checkStatus (NewSession(SessionOptions(), &session));
462+
SessionOptions options;
463+
options.config.set_inter_op_parallelism_threads(num_inter_nthreads);
464+
options.config.set_intra_op_parallelism_threads(num_intra_nthreads);
465+
checkStatus (NewSession(options, &session));
436466
checkStatus (ReadBinaryProto(Env::Default(), model, &graph_def));
437467
checkStatus (session->Create(graph_def));
438468
rcut = get_rcut();
@@ -453,6 +483,8 @@ print_summary(const string &pre) const
453483
cout << pre << "build float prec: " + global_float_prec << endl;
454484
cout << pre << "build with tf inc: " + global_tf_include_dir << endl;
455485
cout << pre << "build with tf lib: " + global_tf_lib << endl;
486+
cout << pre << "set tf intra_op_parallelism_threads: " << num_intra_nthreads << endl;
487+
cout << pre << "set tf inter_op_parallelism_threads: " << num_inter_nthreads << endl;
456488
}
457489

458490

@@ -592,16 +624,21 @@ NNPInterModelDevi ()
592624
: inited (false),
593625
numb_models (0)
594626
{
627+
get_env_nthreads(num_intra_nthreads, num_inter_nthreads);
595628
}
596629

597630
NNPInterModelDevi::
598631
NNPInterModelDevi (const vector<string> & models)
599632
{
633+
get_env_nthreads(num_intra_nthreads, num_inter_nthreads);
600634
numb_models = models.size();
601635
sessions.resize(numb_models);
602636
graph_defs.resize(numb_models);
637+
SessionOptions options;
638+
options.config.set_inter_op_parallelism_threads(num_inter_nthreads);
639+
options.config.set_intra_op_parallelism_threads(num_intra_nthreads);
603640
for (unsigned ii = 0; ii < numb_models; ++ii){
604-
checkStatus (NewSession(SessionOptions(), &(sessions[ii])));
641+
checkStatus (NewSession(options, &(sessions[ii])));
605642
checkStatus (ReadBinaryProto(Env::Default(), models[ii], &graph_defs[ii]));
606643
checkStatus (sessions[ii]->Create(graph_defs[ii]));
607644
}
@@ -619,8 +656,11 @@ init (const vector<string> & models)
619656
numb_models = models.size();
620657
sessions.resize(numb_models);
621658
graph_defs.resize(numb_models);
659+
SessionOptions options;
660+
options.config.set_inter_op_parallelism_threads(num_inter_nthreads);
661+
options.config.set_intra_op_parallelism_threads(num_intra_nthreads);
622662
for (unsigned ii = 0; ii < numb_models; ++ii){
623-
checkStatus (NewSession(SessionOptions(), &(sessions[ii])));
663+
checkStatus (NewSession(options, &(sessions[ii])));
624664
checkStatus (ReadBinaryProto(Env::Default(), models[ii], &graph_defs[ii]));
625665
checkStatus (sessions[ii]->Create(graph_defs[ii]));
626666
}

0 commit comments

Comments
 (0)