@@ -408,17 +408,44 @@ run_model (ENERGYTYPE & dener,
408408 nnpmap.backward (datom_virial_.begin (), datom_virial.begin (), 9 );
409409}
410410
411+ static void
412+ get_env_nthreads (int & num_intra_nthreads,
413+ int & num_inter_nthreads)
414+ {
415+ num_intra_nthreads = 0 ;
416+ num_inter_nthreads = 0 ;
417+ const char * env_intra_nthreads = std::getenv (" OMP_NUM_THREADS" );
418+ const char * env_inter_nthreads = std::getenv (" TF_INTER_OP_PARALLELISM_THREADS" );
419+ if (env_intra_nthreads &&
420+ string (env_intra_nthreads) != string (" " ) &&
421+ atoi (env_intra_nthreads) >= 0
422+ ) {
423+ num_intra_nthreads = atoi (env_intra_nthreads);
424+ }
425+ if (env_inter_nthreads &&
426+ string (env_inter_nthreads) != string (" " ) &&
427+ atoi (env_inter_nthreads) >= 0
428+ ) {
429+ num_inter_nthreads = atoi (env_inter_nthreads);
430+ }
431+ }
432+
411433
412434NNPInter::
413435NNPInter ()
414436 : inited (false )
415437{
438+ get_env_nthreads (num_intra_nthreads, num_inter_nthreads);
416439}
417440
418441NNPInter::
419442NNPInter (const string & model)
420443{
421- checkStatus (NewSession (SessionOptions (), &session));
444+ get_env_nthreads (num_intra_nthreads, num_inter_nthreads);
445+ SessionOptions options;
446+ options.config .set_inter_op_parallelism_threads (num_inter_nthreads);
447+ options.config .set_intra_op_parallelism_threads (num_intra_nthreads);
448+ checkStatus (NewSession (options, &session));
422449 checkStatus (ReadBinaryProto (Env::Default (), model, &graph_def));
423450 checkStatus (session->Create (graph_def));
424451 rcut = get_rcut ();
@@ -432,7 +459,10 @@ NNPInter::
432459init (const string & model)
433460{
434461 assert (!inited);
435- checkStatus (NewSession (SessionOptions (), &session));
462+ SessionOptions options;
463+ options.config .set_inter_op_parallelism_threads (num_inter_nthreads);
464+ options.config .set_intra_op_parallelism_threads (num_intra_nthreads);
465+ checkStatus (NewSession (options, &session));
436466 checkStatus (ReadBinaryProto (Env::Default (), model, &graph_def));
437467 checkStatus (session->Create (graph_def));
438468 rcut = get_rcut ();
@@ -453,6 +483,8 @@ print_summary(const string &pre) const
453483 cout << pre << " build float prec: " + global_float_prec << endl;
454484 cout << pre << " build with tf inc: " + global_tf_include_dir << endl;
455485 cout << pre << " build with tf lib: " + global_tf_lib << endl;
486+ cout << pre << " set tf intra_op_parallelism_threads: " << num_intra_nthreads << endl;
487+ cout << pre << " set tf inter_op_parallelism_threads: " << num_inter_nthreads << endl;
456488}
457489
458490
@@ -592,16 +624,21 @@ NNPInterModelDevi ()
592624 : inited (false ),
593625 numb_models (0 )
594626{
627+ get_env_nthreads (num_intra_nthreads, num_inter_nthreads);
595628}
596629
597630NNPInterModelDevi::
598631NNPInterModelDevi (const vector<string> & models)
599632{
633+ get_env_nthreads (num_intra_nthreads, num_inter_nthreads);
600634 numb_models = models.size ();
601635 sessions.resize (numb_models);
602636 graph_defs.resize (numb_models);
637+ SessionOptions options;
638+ options.config .set_inter_op_parallelism_threads (num_inter_nthreads);
639+ options.config .set_intra_op_parallelism_threads (num_intra_nthreads);
603640 for (unsigned ii = 0 ; ii < numb_models; ++ii){
604- checkStatus (NewSession (SessionOptions () , &(sessions[ii])));
641+ checkStatus (NewSession (options , &(sessions[ii])));
605642 checkStatus (ReadBinaryProto (Env::Default (), models[ii], &graph_defs[ii]));
606643 checkStatus (sessions[ii]->Create (graph_defs[ii]));
607644 }
@@ -619,8 +656,11 @@ init (const vector<string> & models)
619656 numb_models = models.size ();
620657 sessions.resize (numb_models);
621658 graph_defs.resize (numb_models);
659+ SessionOptions options;
660+ options.config .set_inter_op_parallelism_threads (num_inter_nthreads);
661+ options.config .set_intra_op_parallelism_threads (num_intra_nthreads);
622662 for (unsigned ii = 0 ; ii < numb_models; ++ii){
623- checkStatus (NewSession (SessionOptions () , &(sessions[ii])));
663+ checkStatus (NewSession (options , &(sessions[ii])));
624664 checkStatus (ReadBinaryProto (Env::Default (), models[ii], &graph_defs[ii]));
625665 checkStatus (sessions[ii]->Create (graph_defs[ii]));
626666 }
0 commit comments