File tree Expand file tree Collapse file tree 3 files changed +13
-10
lines changed Expand file tree Collapse file tree 3 files changed +13
-10
lines changed Original file line number Diff line number Diff line change 20
20
import candle_keras as candle
21
21
22
22
logger = logging .getLogger (__name__ )
23
+ candle .set_parallelism_threads ()
23
24
24
25
additional_definitions = [
25
26
{'name' :'latent_dim' ,
Original file line number Diff line number Diff line change 26
26
from keras_utils import build_initializer
27
27
from keras_utils import build_optimizer
28
28
from keras_utils import set_seed
29
+ from keras_utils import set_parallelism_threads
29
30
30
31
from solr_keras import CandleRemoteMonitor , compute_trainable_params , TerminateOnTimeOut
31
32
Original file line number Diff line number Diff line change 16
16
warnings .filterwarnings ("ignore" , category = DeprecationWarning )
17
17
from sklearn .metrics import r2_score
18
18
19
+ import os
20
+ def set_parallelism_threads ():
21
+ if K .backend () == 'tensorflow' and 'NUM_INTRA_THREADS' in os .environ and 'NUM_INTER_THREADS' in os .environ :
22
+ import tensorflow as tf
23
+ # print('Using Thread Parallelism: {} NUM_INTRA_THREADS, {} NUM_INTER_THREADS'.format(os.environ['NUM_INTRA_THREADS'], os.environ['NUM_INTER_THREADS']))
24
+ session_conf = tf .ConfigProto (inter_op_parallelism_threads = int (os .environ ['NUM_INTER_THREADS' ]),
25
+ intra_op_parallelism_threads = int (os .environ ['NUM_INTRA_THREADS' ]))
26
+ sess = tf .Session (graph = tf .get_default_graph (), config = session_conf )
27
+ K .set_session (sess )
28
+
29
+
19
30
20
31
def set_seed (seed ):
21
32
set_seed_defaultUtils (seed )
22
33
23
34
if K .backend () == 'tensorflow' :
24
35
import tensorflow as tf
25
36
tf .set_random_seed (seed )
26
- # session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
27
- # sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
28
- # K.set_session(sess)
29
-
30
- # Uncommit when running on an optimized tensorflow where NUM_INTER_THREADS and
31
- # NUM_INTRA_THREADS env vars are set.
32
- # session_conf = tf.ConfigProto(inter_op_parallelism_threads=int(os.environ['NUM_INTER_THREADS']),
33
- # intra_op_parallelism_threads=int(os.environ['NUM_INTRA_THREADS']))
34
- # sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
35
- # K.set_session(sess)
36
37
37
38
38
39
def get_function (name ):
You can’t perform that action at this time.
0 commit comments