|
1 | 1 | # Setup
|
2 | 2 |
|
3 | 3 | import os
|
4 |
| -import sys |
5 |
| -# import gzip |
6 |
| - |
7 |
| -# import math |
8 |
| -# import matplotlib |
9 |
| -# matplotlib.use('Agg') |
10 |
| - |
11 |
| -# import matplotlib.pyplot as plt |
12 | 4 |
|
13 | 5 | from tensorflow.keras import backend as K
|
14 |
| -import tensorflow.keras.optimizers as optimizers |
15 | 6 | from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping
|
16 | 7 |
|
17 | 8 | file_path = os.path.dirname(os.path.realpath(__file__))
|
18 |
| -lib_path = os.path.abspath(os.path.join(file_path, '..', '..', 'common')) |
19 |
| -sys.path.append(lib_path) |
20 | 9 |
|
21 | 10 | import candle
|
22 | 11 | import smiles_transformer as st
|
@@ -52,20 +41,24 @@ def run(params):
|
52 | 41 |
|
53 | 42 | model = st.transformer_model(params)
|
54 | 43 |
|
55 |
| - optimizer = optimizers.deserialize({'class_name': params['optimizer'], 'config': {}}) |
| 44 | + kerasDefaults = candle.keras_default_config() |
| 45 | + |
| 46 | + optimizer = candle.build_optimizer(params['optimizer'], params['learning_rate'], kerasDefaults) |
| 47 | + |
| 48 | + # optimizer = optimizers.deserialize({'class_name': params['optimizer'], 'config': {}}) |
56 | 49 |
|
57 | 50 | # I don't know why we set base_lr. It doesn't appear to be used.
|
58 | 51 | # if 'base_lr' in params and params['base_lr'] > 0:
|
59 | 52 | # base_lr = params['base_lr']
|
60 | 53 | # else:
|
61 | 54 | # base_lr = K.get_value(optimizer.lr)
|
62 | 55 |
|
63 |
| - if 'learning_rate' in params and params['learning_rate'] > 0: |
64 |
| - K.set_value(optimizer.lr, params['learning_rate']) |
65 |
| - print('Done setting optimizer {} learning rate to {}'.format( |
66 |
| - params['optimizer'], params['learning_rate'])) |
| 56 | + # if 'learning_rate' in params and params['learning_rate'] > 0: |
| 57 | + # K.set_value(optimizer.lr, params['learning_rate']) |
| 58 | + # print('Done setting optimizer {} learning rate to {}'.format( |
| 59 | + # params['optimizer'], params['learning_rate'])) |
67 | 60 |
|
68 |
| - model.compile(loss='mean_squared_error', |
| 61 | + model.compile(loss=params['loss'], |
69 | 62 | optimizer=optimizer,
|
70 | 63 | metrics=['mae', st.r2])
|
71 | 64 |
|
|
0 commit comments