Skip to content

Commit 1ae7fe4

Browse files
authored
Merge pull request #24 from bvanessen/release_01
Generic logging functions and P2B1 logging
2 parents 4c1a18e + 9354df2 commit 1ae7fe4

File tree

4 files changed

+32
-7
lines changed

4 files changed

+32
-7
lines changed

Pilot2/P2B1/p2b1.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,15 +328,11 @@ def train_ac(self):
328328
#for frame in random.sample(range(len(xt_all)), int(self.sampling_density*len(xt_all))):
329329
for frame in range(len(xt_all)):
330330
history = self.molecular_model.fit(xt_all[frame], yt_all[frame], epochs=1,
331-
batch_size=self.batch_size, callbacks=self.callbacks[:2],
332-
verbose=0)
331+
batch_size=self.batch_size, callbacks=self.callbacks[:2])
333332
frame_loss.append(history.history['loss'])
334333
frame_mse.append(history.history['mean_squared_error'])
335334

336335
if not frame % 20 or self.sampling_density != 1.0:
337-
print ("Frame: {0:d}, Current history:\nLoss: {1:3.5f}\tMSE: {2:3.5f}\n"
338-
.format(frame, history.history['loss'][0], history.history['mean_squared_error'][0]))
339-
340336
# Update weights filed every few frames
341337
self.molecular_model.save_weights(model_weight_file)
342338
self.molecular_encoder.save_weights(encoder_weight_file)

Pilot2/P2B1/p2b1_baseline_keras2.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys, os, json
55
import argparse
66
import h5py
7+
import logging
78
try:
89
reload # Python 2.7
910
except NameError:
@@ -28,6 +29,9 @@
2829

2930
HOME = os.environ['HOME']
3031

32+
logger = logging.getLogger(__name__)
33+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
34+
3135
def parse_list(option, opt, value, parser):
3236
setattr(parser.values, option.dest, value.split(','))
3337

@@ -86,6 +90,16 @@ def run(GP):
8690
sys.exit(0)
8791
sys.path.append(GP['home_dir'])
8892

93+
# Setup loggin
94+
args = candle.ArgumentStruct(**GP)
95+
# set_seed(args.rng_seed)
96+
# ext = extension_from_parameters(args)
97+
candle.verify_path(args.save_path)
98+
prefix = args.save_path # + ext
99+
logfile = args.logfile if args.logfile else prefix+'.log'
100+
candle.set_up_logger(logfile, logger, False) #args.verbose
101+
logger.info('Params: {}'.format(GP))
102+
89103
import p2b1 as hf
90104
reload(hf)
91105

@@ -238,9 +252,10 @@ def step_decay(epoch):
238252
history = callbacks.History()
239253
# callbacks=[history,lr_scheduler]
240254

255+
history_logger = candle.LoggingCallback(logger.debug)
241256
candleRemoteMonitor = candle.CandleRemoteMonitor(params=GP)
242257
timeoutMonitor = candle.TerminateOnTimeOut(TIMEOUT)
243-
callbacks = [history, candleRemoteMonitor, timeoutMonitor]
258+
callbacks = [history, history_logger, candleRemoteMonitor, timeoutMonitor]
244259
loss = 0.
245260

246261
#### Save the Model to disk

common/candle_keras/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,8 @@
2828
from keras_utils import set_seed
2929
from keras_utils import set_parallelism_threads
3030

31+
from generic_utils import Progbar
32+
from generic_utils import LoggingCallback
33+
3134
from solr_keras import CandleRemoteMonitor, compute_trainable_params, TerminateOnTimeOut
3235

common/generic_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import numpy as np
33
import time
44
import sys
5+
import os
56
import six
67
import marshal
78
import types as python_types
8-
9+
import logging
10+
from keras.callbacks import Callback
911

1012
def get_from_module(identifier, module_params, module_name,
1113
instantiate=False, kwargs=None):
@@ -191,3 +193,12 @@ def display_row(objects, positions):
191193

192194
for objects in rows:
193195
display_row(objects, positions)
196+
197+
class LoggingCallback(Callback):
198+
def __init__(self, print_fcn=print):
199+
Callback.__init__(self)
200+
self.print_fcn = print_fcn
201+
202+
def on_epoch_end(self, epoch, logs={}):
203+
msg = "[Epoch: %i] %s" % (epoch, ", ".join("%s: %f" % (k, v) for k, v in sorted(logs.items())))
204+
self.print_fcn(msg)

0 commit comments

Comments
 (0)