Skip to content

Commit 0396f36

Browse files
committed
add parallel threads configuration in candle library. use example on p1b1
1 parent b612422 commit 0396f36

File tree

3 files changed

+13
-10
lines changed

3 files changed

+13
-10
lines changed

Pilot1/P1B1/p1b1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import candle_keras as candle
2121

2222
logger = logging.getLogger(__name__)
23+
candle.set_parallelism_threads()
2324

2425
additional_definitions = [
2526
{'name':'latent_dim',

common/candle_keras/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from keras_utils import build_initializer
2727
from keras_utils import build_optimizer
2828
from keras_utils import set_seed
29+
from keras_utils import set_parallelism_threads
2930

3031
from solr_keras import CandleRemoteMonitor, compute_trainable_params, TerminateOnTimeOut
3132

common/keras_utils.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,24 @@
1616
warnings.filterwarnings("ignore", category=DeprecationWarning)
1717
from sklearn.metrics import r2_score
1818

19+
import os
20+
def set_parallelism_threads():
21+
if K.backend() == 'tensorflow' and 'NUM_INTRA_THREADS' in os.environ and 'NUM_INTER_THREADS' in os.environ:
22+
import tensorflow as tf
23+
# print('Using Thread Parallelism: {} NUM_INTRA_THREADS, {} NUM_INTER_THREADS'.format(os.environ['NUM_INTRA_THREADS'], os.environ['NUM_INTER_THREADS']))
24+
session_conf = tf.ConfigProto(inter_op_parallelism_threads=int(os.environ['NUM_INTER_THREADS']),
25+
intra_op_parallelism_threads=int(os.environ['NUM_INTRA_THREADS']))
26+
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
27+
K.set_session(sess)
28+
29+
1930

2031
def set_seed(seed):
2132
set_seed_defaultUtils(seed)
2233

2334
if K.backend() == 'tensorflow':
2435
import tensorflow as tf
2536
tf.set_random_seed(seed)
26-
# session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
27-
# sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
28-
# K.set_session(sess)
29-
30-
# Uncommit when running on an optimized tensorflow where NUM_INTER_THREADS and
31-
# NUM_INTRA_THREADS env vars are set.
32-
# session_conf = tf.ConfigProto(inter_op_parallelism_threads=int(os.environ['NUM_INTER_THREADS']),
33-
# intra_op_parallelism_threads=int(os.environ['NUM_INTRA_THREADS']))
34-
# sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
35-
# K.set_session(sess)
3637

3738

3839
def get_function(name):

0 commit comments

Comments
 (0)