Skip to content

Commit 7266a6a

Browse files
committed
Added some logging functionality from the P1Combo model into the
generic_utils and then used them in the P2B1 model to display training progress.
1 parent 9117e7f commit 7266a6a

File tree

4 files changed

+54
-7
lines changed

4 files changed

+54
-7
lines changed

Pilot2/P2B1/p2b1.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,15 +328,11 @@ def train_ac(self):
328328
#for frame in random.sample(range(len(xt_all)), int(self.sampling_density*len(xt_all))):
329329
for frame in range(len(xt_all)):
330330
history = self.molecular_model.fit(xt_all[frame], yt_all[frame], epochs=1,
331-
batch_size=self.batch_size, callbacks=self.callbacks[:2],
332-
verbose=0)
331+
batch_size=self.batch_size, callbacks=self.callbacks[:2])
333332
frame_loss.append(history.history['loss'])
334333
frame_mse.append(history.history['mean_squared_error'])
335334

336335
if not frame % 20 or self.sampling_density != 1.0:
337-
print ("Frame: {0:d}, Current history:\nLoss: {1:3.5f}\tMSE: {2:3.5f}\n"
338-
.format(frame, history.history['loss'][0], history.history['mean_squared_error'][0]))
339-
340336
# Update weights filed every few frames
341337
self.molecular_model.save_weights(model_weight_file)
342338
self.molecular_encoder.save_weights(encoder_weight_file)

Pilot2/P2B1/p2b1_baseline_keras2.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys, os, json
55
import argparse
66
import h5py
7+
import logging
78
try:
89
reload # Python 2.7
910
except NameError:
@@ -28,6 +29,9 @@
2829

2930
HOME = os.environ['HOME']
3031

32+
logger = logging.getLogger(__name__)
33+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
34+
3135
def parse_list(option, opt, value, parser):
3236
setattr(parser.values, option.dest, value.split(','))
3337

@@ -86,6 +90,16 @@ def run(GP):
8690
sys.exit(0)
8791
sys.path.append(GP['home_dir'])
8892

93+
# Setup loggin
94+
args = candle.ArgumentStruct(**GP)
95+
# set_seed(args.rng_seed)
96+
# ext = extension_from_parameters(args)
97+
candle.verify_path(args.save_path)
98+
prefix = args.save_path # + ext
99+
logfile = args.logfile if args.logfile else prefix+'.log'
100+
candle.set_up_logger(logger, logfile, False) #args.verbose
101+
logger.info('Params: {}'.format(GP))
102+
89103
import p2b1 as hf
90104
reload(hf)
91105

@@ -238,9 +252,10 @@ def step_decay(epoch):
238252
history = callbacks.History()
239253
# callbacks=[history,lr_scheduler]
240254

255+
history_logger = candle.LoggingCallback(logger.debug)
241256
candleRemoteMonitor = candle.CandleRemoteMonitor(params=GP)
242257
timeoutMonitor = candle.TerminateOnTimeOut(TIMEOUT)
243-
callbacks = [history, candleRemoteMonitor, timeoutMonitor]
258+
callbacks = [history, history_logger, candleRemoteMonitor, timeoutMonitor]
244259
loss = 0.
245260

246261
#### Save the Model to disk

common/candle_keras/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,10 @@
2828
from keras_utils import set_seed
2929
from keras_utils import set_parallelism_threads
3030

31+
from generic_utils import Progbar
32+
from generic_utils import verify_path
33+
from generic_utils import set_up_logger
34+
from generic_utils import LoggingCallback
35+
3136
from solr_keras import CandleRemoteMonitor, compute_trainable_params, TerminateOnTimeOut
3237

common/generic_utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import numpy as np
33
import time
44
import sys
5+
import os
56
import six
67
import marshal
78
import types as python_types
8-
9+
import logging
10+
from keras.callbacks import Callback
911

1012
def get_from_module(identifier, module_params, module_name,
1113
instantiate=False, kwargs=None):
@@ -191,3 +193,32 @@ def display_row(objects, positions):
191193

192194
for objects in rows:
193195
display_row(objects, positions)
196+
197+
198+
def verify_path(path):
199+
folder = os.path.dirname(path)
200+
if folder and not os.path.exists(folder):
201+
os.makedirs(folder)
202+
203+
def set_up_logger(logger, logfile, verbose):
204+
verify_path(logfile)
205+
fh = logging.FileHandler(logfile)
206+
fh.setFormatter(logging.Formatter("[%(asctime)s %(process)d] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
207+
fh.setLevel(logging.DEBUG)
208+
209+
sh = logging.StreamHandler()
210+
sh.setFormatter(logging.Formatter(''))
211+
sh.setLevel(logging.DEBUG if verbose else logging.INFO)
212+
213+
logger.setLevel(logging.DEBUG)
214+
logger.addHandler(fh)
215+
logger.addHandler(sh)
216+
217+
class LoggingCallback(Callback):
218+
def __init__(self, print_fcn=print):
219+
Callback.__init__(self)
220+
self.print_fcn = print_fcn
221+
222+
def on_epoch_end(self, epoch, logs={}):
223+
msg = "[Epoch: %i] %s" % (epoch, ", ".join("%s: %f" % (k, v) for k, v in sorted(logs.items())))
224+
self.print_fcn(msg)

0 commit comments

Comments
 (0)