Skip to content

Commit 1356721

Browse files
author
brettin
committed
added optimizer param and set_memory_growth
1 parent 155ad7a commit 1356721

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

examples/xform-smiles/srt_baseline_keras.py

Lines changed: 24 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -11,7 +11,7 @@
1111
# import matplotlib.pyplot as plt
1212

1313
from tensorflow.keras import backend as K
14-
from tensorflow.keras.optimizers import Adam # RMSprop, SGD
14+
import tensorflow.keras.optimizers as optimizers
1515
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping
1616

1717
file_path = os.path.dirname(os.path.realpath(__file__))
@@ -21,6 +21,15 @@
2121
import candle
2222
import smiles_transformer as st
2323

24+
import tensorflow.config.experimental
25+
gpus = tensorflow.config.experimental.list_physical_devices('GPU')
26+
try:
27+
for gpu in gpus:
28+
print("setting memory growth")
29+
tensorflow.config.experimental.set_memory_growth(gpu, True)
30+
except RuntimeError as e:
31+
print(e)
32+
2433

2534
def initialize_parameters(default_model='regress_default_model.txt'):
2635

@@ -43,8 +52,21 @@ def run(params):
4352

4453
model = st.transformer_model(params)
4554

55+
optimizer = optimizers.deserialize({'class_name': params['optimizer'], 'config': {}})
56+
57+
# I don't know why we set base_lr. It doesn't appear to be used.
58+
if 'base_lr' in params and params['base_lr'] > 0:
59+
base_lr = params['base_lr']
60+
else:
61+
base_lr = K.get_value(optimizer.lr)
62+
63+
if 'learning_rate' in params and params['learning_rate'] > 0:
64+
K.set_value(optimizer.lr, params['learning_rate'])
65+
print('Done setting optimizer {} learning rate to {}'.format(
66+
params['optimizer'],params['learning_rate']))
67+
4668
model.compile(loss='mean_squared_error',
47-
optimizer=Adam(lr=0.00001),
69+
optimizer=optimizer,
4870
metrics=['mae', st.r2])
4971

5072
# set up a bunch of callbacks to do work during model training..

0 commit comments

Comments (0)