Skip to content

Commit be45422

Browse files
authored
Merge pull request #156 from amcadmus/devel
control the behavior of tensorflow multithreading
2 parents 4615af8 + 140762f commit be45422

File tree

12 files changed

+64
-54
lines changed

12 files changed

+64
-54
lines changed

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,6 @@ positional arguments:
437437
438438
optional arguments:
439439
-h, --help show this help message and exit
440-
-t INTER_THREADS, --inter-threads INTER_THREADS
441-
With default value 0. Setting the "inter_op_parallelism_threads" key for the tensorflow, the "intra_op_parallelism_threads" will be set by the env variable OMP_NUM_THREADS
442440
--init-model INIT_MODEL
443441
Initialize a model by the provided checkpoint
444442
--restart RESTART Restart the training from the provided checkpoint
@@ -449,6 +447,15 @@ The keys `intra_op_parallelism_threads` and `inter_op_parallelism_threads` are T
449447

450448
**`--restart model.ckpt`**, continues the training from the checkpoint `model.ckpt`.
451449

450+
On some resources limited machines, one may want to control the number of threads used by DeePMD-kit. This is achieved by three environmental variables: `OMP_NUM_THREADS`, `TF_INTRA_OP_PARALLELISM_THREADS` and `TF_INTER_OP_PARALLELISM_THREADS`. `OMP_NUM_THREADS` controls the multithreading of DeePMD-kit implemented operations. `TF_INTRA_OP_PARALLELISM_THREADS` and `TF_INTER_OP_PARALLELISM_THREADS` controls `intra_op_parallelism_threads` and `inter_op_parallelism_threads`, which are Tensorflow configurations for multithreading. An explanation is found [here](https://stackoverflow.com/questions/41233635/meaning-of-inter-op-parallelism-threads-and-intra-op-parallelism-threads).
451+
452+
For example if you wish to use 3 cores of 2 CPUs on one node, you may set the environmental variables and run DeePMD-kit as follows:
453+
```bash
454+
export OMP_NUM_THREADS=6
455+
export TF_INTRA_OP_PARALLELISM_THREADS=3
456+
export TF_INTER_OP_PARALLELISM_THREADS=2
457+
dp train input.json
458+
```
452459

453460
## Freeze a model
454461

source/lib/src/common.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ get_env_nthreads(int & num_intra_nthreads,
159159
{
160160
num_intra_nthreads = 0;
161161
num_inter_nthreads = 0;
162-
const char* env_intra_nthreads = std::getenv("OMP_NUM_THREADS");
162+
const char* env_intra_nthreads = std::getenv("TF_INTRA_OP_PARALLELISM_THREADS");
163163
const char* env_inter_nthreads = std::getenv("TF_INTER_OP_PARALLELISM_THREADS");
164164
if (env_intra_nthreads &&
165165
string(env_intra_nthreads) != string("") &&

source/train/DeepEval.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import numpy as np
55

66
from deepmd.env import tf
7+
from deepmd.env import default_tf_session_config
78
from deepmd.common import make_default_mesh
89

9-
1010
class DeepEval():
1111
"""
1212
common methods for DeepPot, DeepWFC, DeepPolar, ...
@@ -16,7 +16,7 @@ def __init__(self,
1616
load_prefix = 'load') :
1717
self.graph = self._load_graph (model_file, prefix = load_prefix)
1818
t_mt = self.graph.get_tensor_by_name(os.path.join(load_prefix, 'model_attr/model_type:0'))
19-
sess = tf.Session (graph = self.graph)
19+
sess = tf.Session (graph = self.graph, config=default_tf_session_config)
2020
[mt] = sess.run([t_mt], feed_dict = {})
2121
self.model_type = mt.decode('utf-8')
2222

@@ -122,7 +122,7 @@ def __init__(self,
122122
# outputs
123123
self.t_tensor = self.graph.get_tensor_by_name (os.path.join(load_prefix, 'o_%s:0' % self.variable_name))
124124
# start a tf session associated to the graph
125-
self.sess = tf.Session (graph = self.graph)
125+
self.sess = tf.Session (graph = self.graph, config=default_tf_session_config)
126126
[self.ntypes, self.rcut, self.tmap, self.tselt] = self.sess.run([self.t_ntypes, self.t_rcut, self.t_tmap, self.t_sel_type])
127127
self.tmap = self.tmap.decode('UTF-8').split()
128128

source/train/DeepPot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
from deepmd.env import tf
5+
from deepmd.env import default_tf_session_config
56
from deepmd.common import make_default_mesh
67
from deepmd.DeepEval import DeepEval
78
from deepmd.DataModifier import DipoleChargeModifier
@@ -43,7 +44,7 @@ def __init__(self,
4344
self.t_aparam = self.graph.get_tensor_by_name ('load/t_aparam:0')
4445
self.has_aparam = self.t_aparam is not None
4546
# start a tf session associated to the graph
46-
self.sess = tf.Session (graph = self.graph)
47+
self.sess = tf.Session (graph = self.graph, config=default_tf_session_config)
4748
[self.ntypes, self.rcut, self.dfparam, self.daparam, self.tmap] = self.sess.run([self.t_ntypes, self.t_rcut, self.t_dfparam, self.t_daparam, self.t_tmap])
4849
self.tmap = self.tmap.decode('UTF-8').split()
4950
# setup modifier

source/train/DescrptLocFrame.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
7+
from deepmd.env import default_tf_session_config
78

89
class DescrptLocFrame () :
910
def __init__(self, jdata):
@@ -55,7 +56,7 @@ def __init__(self, jdata):
5556
sel_a = self.sel_a,
5657
sel_r = self.sel_r,
5758
axis_rule = self.axis_rule)
58-
self.sub_sess = tf.Session(graph = sub_graph)
59+
self.sub_sess = tf.Session(graph = sub_graph, config=default_tf_session_config)
5960

6061

6162
def get_rcut (self) :

source/train/DescrptSeA.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
7+
from deepmd.env import default_tf_session_config
78

89
class DescrptSeA ():
910
def __init__ (self, jdata):
@@ -66,7 +67,7 @@ def __init__ (self, jdata):
6667
rcut_r_smth = self.rcut_r_smth,
6768
sel_a = self.sel_a,
6869
sel_r = self.sel_r)
69-
self.sub_sess = tf.Session(graph = sub_graph)
70+
self.sub_sess = tf.Session(graph = sub_graph, config=default_tf_session_config)
7071

7172

7273
def get_rcut (self) :

source/train/DescrptSeR.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
7+
from deepmd.env import default_tf_session_config
78

89
class DescrptSeR ():
910
def __init__ (self, jdata):
@@ -60,7 +61,7 @@ def __init__ (self, jdata):
6061
rcut = self.rcut,
6162
rcut_smth = self.rcut_smth,
6263
sel = self.sel_r)
63-
self.sub_sess = tf.Session(graph = sub_graph)
64+
self.sub_sess = tf.Session(graph = sub_graph, config=default_tf_session_config)
6465

6566

6667
def get_rcut (self) :

source/train/EwaldRecp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from deepmd.RunOptions import global_cvt_2_tf_float
88
from deepmd.RunOptions import global_cvt_2_ener_float
99
from deepmd.env import op_module
10+
from deepmd.env import default_tf_session_config
1011

1112
class EwaldRecp () :
1213
def __init__(self,
@@ -25,7 +26,7 @@ def __init__(self,
2526
= op_module.ewald_recp(self.t_coord, self.t_charge, self.t_nloc, self.t_box,
2627
ewald_h = self.hh,
2728
ewald_beta = self.beta)
28-
self.sess = tf.Session(graph=graph)
29+
self.sess = tf.Session(graph=graph, config=default_tf_session_config)
2930

3031
def eval(self,
3132
coord,

source/train/RunOptions.py.in

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os,sys
2-
import tensorflow as tf
2+
from deepmd.env import tf
3+
from deepmd.env import get_tf_default_nthreads
34
import numpy as np
45
import deepmd.cluster.Slurm as Slurm
56
import deepmd.cluster.Local as Local
@@ -28,14 +29,6 @@ global_git_branch='@GIT_BRANCH@'
2829
global_tf_include_dir='@TensorFlow_INCLUDE_DIRS@'
2930
global_tf_libs='@TensorFlow_LIBRARY@'
3031

31-
def _get_threads_env () :
32-
num_intra_threads = None
33-
if 'OMP_NUM_THREADS' in os.environ :
34-
num_intra_threads = int(os.environ['OMP_NUM_THREADS'])
35-
else :
36-
num_intra_threads = 0
37-
return num_intra_threads
38-
3932
def _is_slurm() :
4033
return "SLURM_JOB_NODELIST" in os.environ
4134

@@ -106,10 +99,6 @@ class RunOptions (object) :
10699
def __init__ (self,
107100
args,
108101
try_distrib = False):
109-
# thread settings
110-
self.num_intra_threads = _get_threads_env()
111-
self.num_inter_threads = 0
112-
113102
# distributed tasks
114103
if try_distrib :
115104
self._try_init_mpi()
@@ -132,8 +121,6 @@ class RunOptions (object) :
132121
if args.restart is not None:
133122
self.restart = os.path.abspath(args.restart)
134123
self.init_mode = "restart"
135-
if args.inter_threads is not None :
136-
self.num_inter_threads = args.inter_threads
137124

138125
def message (self, msg) :
139126
if self.verbose :
@@ -167,28 +154,32 @@ class RunOptions (object) :
167154
def print_summary(self) :
168155
msg = ""
169156
msg += "---Summary of the training---------------------------------------\n"
170-
msg += 'installed to: %s\n' % global_install_prefix
171-
msg += 'source : %s\n' % global_git_summ
172-
msg += 'source brach: %s\n' % global_git_branch
173-
msg += 'source commit: %s\n' % global_git_hash
174-
msg += 'source commit at: %s\n' % global_git_date
175-
msg += 'build float prec: %s\n' % global_float_prec
176-
msg += 'build with tf inc: %s\n' % global_tf_include_dir
157+
msg += 'installed to: %s\n' % global_install_prefix
158+
msg += 'source : %s\n' % global_git_summ
159+
msg += 'source brach: %s\n' % global_git_branch
160+
msg += 'source commit: %s\n' % global_git_hash
161+
msg += 'source commit at: %s\n' % global_git_date
162+
msg += 'build float prec: %s\n' % global_float_prec
163+
msg += 'build with tf inc: %s\n' % global_tf_include_dir
177164
for idx,ii in enumerate(global_tf_libs.split(';')) :
178165
if idx == 0 :
179-
msg += 'build with tf lib: %s\n' % ii
166+
msg += 'build with tf lib: %s\n' % ii
180167
else :
181-
msg += ' %s\n' % ii
168+
msg += ' %s\n' % ii
182169
if self.is_distrib:
183170
msg += "distributed\n"
184-
msg += "ps list: %s\n" % str(self.cluster['ps'])
185-
msg += "worker list: %s\n" % str(self.cluster['worker'])
186-
msg += "chief on: %s\n" % self.nodename
171+
msg += "ps list: %s\n" % str(self.cluster['ps'])
172+
msg += "worker list: %s\n" % str(self.cluster['worker'])
173+
msg += "chief on: %s\n" % self.nodename
187174
else :
188-
msg += "running on: %s\n" % self.nodename
189-
msg += "gpu per node: %s\n" % self.gpus
190-
msg += "num_inter_threads: %d\n" % self.num_inter_threads
191-
msg += "num_intra_threads: %d\n" % self.num_intra_threads
175+
msg += "running on: %s\n" % self.nodename
176+
if self.gpus is None:
177+
msg += "CUDA_VISIBLE_DEVICES: unset\n"
178+
else:
179+
msg += "CUDA_VISIBLE_DEVICES: %s\n" % self.gpus
180+
intra, inter = get_tf_default_nthreads()
181+
msg += "num_intra_threads: %d\n" % intra
182+
msg += "num_inter_threads: %d\n" % inter
192183
msg += "-----------------------------------------------------------------\n"
193184
self.message(msg)
194185

source/train/Trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import shutil
55
import numpy as np
66
from deepmd.env import tf
7+
from deepmd.env import default_tf_session_config
78
from deepmd.RunOptions import global_tf_float_precision
89
from deepmd.RunOptions import global_ener_float_precision
910
from deepmd.Fitting import EnerFitting, WFCFitting, PolarFittingLocFrame, PolarFittingSeA, GlobalPolarFittingSeA, DipoleFittingSeA
@@ -288,10 +289,7 @@ def _build_training(self):
288289
self._message("built training")
289290

290291
def _init_sess_serial(self) :
291-
self.sess = tf.Session(
292-
config=tf.ConfigProto(intra_op_parallelism_threads=self.run_opt.num_intra_threads,
293-
inter_op_parallelism_threads=self.run_opt.num_inter_threads
294-
))
292+
self.sess = tf.Session(config=default_tf_session_config)
295293
self.saver = tf.train.Saver()
296294
saver = self.saver
297295
if self.run_opt.init_mode == 'init_from_scratch' :

0 commit comments

Comments
 (0)