Skip to content

Commit 9682111

Browse files
author
Han Wang
committed
Use the environment variables TF_INTRA_OP_PARALLELISM_THREADS and TF_INTER_OP_PARALLELISM_THREADS to control TensorFlow multi-threading; clean up RunOptions.
1 parent 0ffb5bf commit 9682111

File tree

10 files changed

+55
-47
lines changed

10 files changed

+55
-47
lines changed

source/lib/src/common.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ get_env_nthreads(int & num_intra_nthreads,
159159
{
160160
num_intra_nthreads = 0;
161161
num_inter_nthreads = 0;
162-
const char* env_intra_nthreads = std::getenv("OMP_NUM_THREADS");
162+
const char* env_intra_nthreads = std::getenv("TF_INTRA_OP_PARALLELISM_THREADS");
163163
const char* env_inter_nthreads = std::getenv("TF_INTER_OP_PARALLELISM_THREADS");
164164
if (env_intra_nthreads &&
165165
string(env_intra_nthreads) != string("") &&

source/train/DeepEval.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import numpy as np
55

66
from deepmd.env import tf
7+
from deepmd.env import default_tf_session_config
78
from deepmd.common import make_default_mesh
89

9-
1010
class DeepEval():
1111
"""
1212
common methods for DeepPot, DeepWFC, DeepPolar, ...
@@ -16,7 +16,7 @@ def __init__(self,
1616
load_prefix = 'load') :
1717
self.graph = self._load_graph (model_file, prefix = load_prefix)
1818
t_mt = self.graph.get_tensor_by_name(os.path.join(load_prefix, 'model_attr/model_type:0'))
19-
sess = tf.Session (graph = self.graph)
19+
sess = tf.Session (graph = self.graph, config=default_tf_session_config)
2020
[mt] = sess.run([t_mt], feed_dict = {})
2121
self.model_type = mt.decode('utf-8')
2222

@@ -122,7 +122,7 @@ def __init__(self,
122122
# outputs
123123
self.t_tensor = self.graph.get_tensor_by_name (os.path.join(load_prefix, 'o_%s:0' % self.variable_name))
124124
# start a tf session associated to the graph
125-
self.sess = tf.Session (graph = self.graph)
125+
self.sess = tf.Session (graph = self.graph, config=default_tf_session_config)
126126
[self.ntypes, self.rcut, self.tmap, self.tselt] = self.sess.run([self.t_ntypes, self.t_rcut, self.t_tmap, self.t_sel_type])
127127
self.tmap = self.tmap.decode('UTF-8').split()
128128

source/train/DeepPot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
from deepmd.env import tf
5+
from deepmd.env import default_tf_session_config
56
from deepmd.common import make_default_mesh
67
from deepmd.DeepEval import DeepEval
78
from deepmd.DataModifier import DipoleChargeModifier
@@ -43,7 +44,7 @@ def __init__(self,
4344
self.t_aparam = self.graph.get_tensor_by_name ('load/t_aparam:0')
4445
self.has_aparam = self.t_aparam is not None
4546
# start a tf session associated to the graph
46-
self.sess = tf.Session (graph = self.graph)
47+
self.sess = tf.Session (graph = self.graph, config=default_tf_session_config)
4748
[self.ntypes, self.rcut, self.dfparam, self.daparam, self.tmap] = self.sess.run([self.t_ntypes, self.t_rcut, self.t_dfparam, self.t_daparam, self.t_tmap])
4849
self.tmap = self.tmap.decode('UTF-8').split()
4950
# setup modifier

source/train/DescrptLocFrame.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
7+
from deepmd.env import default_tf_session_config
78

89
class DescrptLocFrame () :
910
def __init__(self, jdata):
@@ -55,7 +56,7 @@ def __init__(self, jdata):
5556
sel_a = self.sel_a,
5657
sel_r = self.sel_r,
5758
axis_rule = self.axis_rule)
58-
self.sub_sess = tf.Session(graph = sub_graph)
59+
self.sub_sess = tf.Session(graph = sub_graph, config=default_tf_session_config)
5960

6061

6162
def get_rcut (self) :

source/train/DescrptSeA.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
7+
from deepmd.env import default_tf_session_config
78

89
class DescrptSeA ():
910
def __init__ (self, jdata):
@@ -66,7 +67,7 @@ def __init__ (self, jdata):
6667
rcut_r_smth = self.rcut_r_smth,
6768
sel_a = self.sel_a,
6869
sel_r = self.sel_r)
69-
self.sub_sess = tf.Session(graph = sub_graph)
70+
self.sub_sess = tf.Session(graph = sub_graph, config=default_tf_session_config)
7071

7172

7273
def get_rcut (self) :

source/train/DescrptSeR.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
7+
from deepmd.env import default_tf_session_config
78

89
class DescrptSeR ():
910
def __init__ (self, jdata):
@@ -60,7 +61,7 @@ def __init__ (self, jdata):
6061
rcut = self.rcut,
6162
rcut_smth = self.rcut_smth,
6263
sel = self.sel_r)
63-
self.sub_sess = tf.Session(graph = sub_graph)
64+
self.sub_sess = tf.Session(graph = sub_graph, config=default_tf_session_config)
6465

6566

6667
def get_rcut (self) :

source/train/EwaldRecp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from deepmd.RunOptions import global_cvt_2_tf_float
88
from deepmd.RunOptions import global_cvt_2_ener_float
99
from deepmd.env import op_module
10+
from deepmd.env import default_tf_session_config
1011

1112
class EwaldRecp () :
1213
def __init__(self,
@@ -25,7 +26,7 @@ def __init__(self,
2526
= op_module.ewald_recp(self.t_coord, self.t_charge, self.t_nloc, self.t_box,
2627
ewald_h = self.hh,
2728
ewald_beta = self.beta)
28-
self.sess = tf.Session(graph=graph)
29+
self.sess = tf.Session(graph=graph, config=default_tf_session_config)
2930

3031
def eval(self,
3132
coord,

source/train/RunOptions.py.in

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os,sys
2-
import tensorflow as tf
2+
from deepmd.env import tf
3+
from deepmd.env import get_tf_default_nthreads
34
import numpy as np
45
import deepmd.cluster.Slurm as Slurm
56
import deepmd.cluster.Local as Local
@@ -28,14 +29,6 @@ global_git_branch='@GIT_BRANCH@'
2829
global_tf_include_dir='@TensorFlow_INCLUDE_DIRS@'
2930
global_tf_libs='@TensorFlow_LIBRARY@'
3031

31-
def _get_threads_env () :
32-
num_intra_threads = None
33-
if 'OMP_NUM_THREADS' in os.environ :
34-
num_intra_threads = int(os.environ['OMP_NUM_THREADS'])
35-
else :
36-
num_intra_threads = 0
37-
return num_intra_threads
38-
3932
def _is_slurm() :
4033
return "SLURM_JOB_NODELIST" in os.environ
4134

@@ -106,10 +99,6 @@ class RunOptions (object) :
10699
def __init__ (self,
107100
args,
108101
try_distrib = False):
109-
# thread settings
110-
self.num_intra_threads = _get_threads_env()
111-
self.num_inter_threads = 0
112-
113102
# distributed tasks
114103
if try_distrib :
115104
self._try_init_mpi()
@@ -132,8 +121,6 @@ class RunOptions (object) :
132121
if args.restart is not None:
133122
self.restart = os.path.abspath(args.restart)
134123
self.init_mode = "restart"
135-
if args.inter_threads is not None :
136-
self.num_inter_threads = args.inter_threads
137124

138125
def message (self, msg) :
139126
if self.verbose :
@@ -167,28 +154,32 @@ class RunOptions (object) :
167154
def print_summary(self) :
168155
msg = ""
169156
msg += "---Summary of the training---------------------------------------\n"
170-
msg += 'installed to: %s\n' % global_install_prefix
171-
msg += 'source : %s\n' % global_git_summ
172-
msg += 'source brach: %s\n' % global_git_branch
173-
msg += 'source commit: %s\n' % global_git_hash
174-
msg += 'source commit at: %s\n' % global_git_date
175-
msg += 'build float prec: %s\n' % global_float_prec
176-
msg += 'build with tf inc: %s\n' % global_tf_include_dir
157+
msg += 'installed to: %s\n' % global_install_prefix
158+
msg += 'source : %s\n' % global_git_summ
159+
msg += 'source brach: %s\n' % global_git_branch
160+
msg += 'source commit: %s\n' % global_git_hash
161+
msg += 'source commit at: %s\n' % global_git_date
162+
msg += 'build float prec: %s\n' % global_float_prec
163+
msg += 'build with tf inc: %s\n' % global_tf_include_dir
177164
for idx,ii in enumerate(global_tf_libs.split(';')) :
178165
if idx == 0 :
179-
msg += 'build with tf lib: %s\n' % ii
166+
msg += 'build with tf lib: %s\n' % ii
180167
else :
181-
msg += ' %s\n' % ii
168+
msg += ' %s\n' % ii
182169
if self.is_distrib:
183170
msg += "distributed\n"
184-
msg += "ps list: %s\n" % str(self.cluster['ps'])
185-
msg += "worker list: %s\n" % str(self.cluster['worker'])
186-
msg += "chief on: %s\n" % self.nodename
171+
msg += "ps list: %s\n" % str(self.cluster['ps'])
172+
msg += "worker list: %s\n" % str(self.cluster['worker'])
173+
msg += "chief on: %s\n" % self.nodename
187174
else :
188-
msg += "running on: %s\n" % self.nodename
189-
msg += "gpu per node: %s\n" % self.gpus
190-
msg += "num_inter_threads: %d\n" % self.num_inter_threads
191-
msg += "num_intra_threads: %d\n" % self.num_intra_threads
175+
msg += "running on: %s\n" % self.nodename
176+
if self.gpus is None:
177+
msg += "CUDA_VISIBLE_DEVICES: unset\n"
178+
else:
179+
msg += "CUDA_VISIBLE_DEVICES: %s\n" % self.gpus
180+
intra, inter = get_tf_default_nthreads()
181+
msg += "num_intra_threads: %d\n" % intra
182+
msg += "num_inter_threads: %d\n" % inter
192183
msg += "-----------------------------------------------------------------\n"
193184
self.message(msg)
194185

source/train/Trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import shutil
55
import numpy as np
66
from deepmd.env import tf
7+
from deepmd.env import default_tf_session_config
78
from deepmd.RunOptions import global_tf_float_precision
89
from deepmd.RunOptions import global_ener_float_precision
910
from deepmd.Fitting import EnerFitting, WFCFitting, PolarFittingLocFrame, PolarFittingSeA, GlobalPolarFittingSeA, DipoleFittingSeA
@@ -288,10 +289,7 @@ def _build_training(self):
288289
self._message("built training")
289290

290291
def _init_sess_serial(self) :
291-
self.sess = tf.Session(
292-
config=tf.ConfigProto(intra_op_parallelism_threads=self.run_opt.num_intra_threads,
293-
inter_op_parallelism_threads=self.run_opt.num_inter_threads
294-
))
292+
self.sess = tf.Session(config=default_tf_session_config)
295293
self.saver = tf.train.Saver()
296294
saver = self.saver
297295
if self.run_opt.init_mode == 'init_from_scratch' :

source/train/env.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
except ImportError:
1212
import tensorflow as tf
1313

14-
def set_env_if_empty(key, value):
14+
def set_env_if_empty(key, value, verbose=True):
1515
if os.environ.get(key) is None:
1616
os.environ[key] = value
17-
logging.warn("Environment variable {} is empty. Use the default value {}".format(key, value))
17+
if verbose:
18+
logging.warn("Environment variable {} is empty. Use the default value {}".format(key, value))
1819

1920
def set_mkl():
2021
"""Tuning MKL for the best performance
@@ -32,6 +33,18 @@ def set_mkl():
3233
set_env_if_empty("KMP_AFFINITY", "granularity=fine,verbose,compact,1,0")
3334
reload(np)
3435

36+
def set_tf_default_nthreads():
    """Give the TensorFlow thread-count environment variables a default.

    For each of TF_INTRA_OP_PARALLELISM_THREADS and
    TF_INTER_OP_PARALLELISM_THREADS, the value "0" is written only when the
    variable is not already present; an existing value is left untouched.
    The usual "variable is empty" warning is suppressed.
    """
    thread_env_keys = (
        "TF_INTRA_OP_PARALLELISM_THREADS",
        "TF_INTER_OP_PARALLELISM_THREADS",
    )
    for env_key in thread_env_keys:
        set_env_if_empty(env_key, "0", verbose=False)
39+
40+
def get_tf_default_nthreads():
    """Read the TensorFlow thread counts requested through the environment.

    Looks up TF_INTRA_OP_PARALLELISM_THREADS and
    TF_INTER_OP_PARALLELISM_THREADS. A variable that is unset or set to the
    empty string is reported as 0, so this function can be called safely
    even if the defaults have not been written to the environment yet.

    Returns
    -------
    tuple of (int, int)
        ``(intra, inter)`` thread counts, in that order.

    Raises
    ------
    ValueError
        If a variable is set to a non-empty string that is not an integer.
    """
    def _read_nthreads(key):
        # `or "0"` covers both a missing variable (None) and an empty one
        # (""); int() would otherwise raise TypeError / ValueError.
        return int(os.environ.get(key) or "0")

    return (
        _read_nthreads('TF_INTRA_OP_PARALLELISM_THREADS'),
        _read_nthreads('TF_INTER_OP_PARALLELISM_THREADS'),
    )
42+
43+
def get_tf_session_config():
    """Build the session configuration from the thread-count environment.

    Ensures the two thread-count environment variables have a value, reads
    them back, and passes them to ``tf.ConfigProto`` as
    ``intra_op_parallelism_threads`` and ``inter_op_parallelism_threads``.

    Returns
    -------
    The ``tf.ConfigProto`` object constructed from those two settings.
    """
    # Defaults must be in place before the values are read back.
    set_tf_default_nthreads()
    n_intra, n_inter = get_tf_default_nthreads()
    session_config = tf.ConfigProto(
        intra_op_parallelism_threads=n_intra,
        inter_op_parallelism_threads=n_inter,
    )
    return session_config
47+
3548
def get_module(module_name):
3649
"""Load force module."""
3750
if platform.system() == "Windows":
@@ -46,4 +59,5 @@ def get_module(module_name):
4659
return module
4760

4861
op_module = get_module("libop_abi")
49-
op_grads_module = get_module("libop_grads")
62+
op_grads_module = get_module("libop_grads")
63+
default_tf_session_config = get_tf_session_config()

0 commit comments

Comments
 (0)