Skip to content

Commit 0ba1d75

Browse files
committed
Added NT3 to Release01. Removed infer.py since I cannot test it.
1 parent 32a2f16 commit 0ba1d75

File tree

4 files changed

+90
-256
lines changed


Pilot1/NT3/infer.py

Lines changed: 0 additions & 145 deletions
This file was deleted.

Pilot1/NT3/nt3.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import os
2+
import sys
3+
4+
file_path = os.path.dirname(os.path.realpath(__file__))
5+
lib_path2 = os.path.abspath(os.path.join(file_path, '..', '..', 'common'))
6+
sys.path.append(lib_path2)
7+
8+
import candle_keras as candle
9+
10+
additional_definitions = [
11+
{'name':'model_name',
12+
'default':'nt3',
13+
'type':str},
14+
{'name':'classes',
15+
'type':int,
16+
'default':2}
17+
]
18+
19+
required = [
20+
'data_url',
21+
'train_data',
22+
'test_data',
23+
'model_name',
24+
'conv',
25+
'dense',
26+
'activation',
27+
'out_act',
28+
'loss',
29+
'optimizer',
30+
'metrics',
31+
'epochs',
32+
'batch_size',
33+
'learning_rate',
34+
'drop',
35+
'classes',
36+
'pool',
37+
'save',
38+
'timeout'
39+
]
40+
41+
class BenchmarkNT3(candle.Benchmark):
42+
43+
def set_locals(self):
44+
"""Functionality to set variables specific for the benchmark
45+
- required: set of required parameters for the benchmark.
46+
- additional_definitions: list of dictionaries describing the additional parameters for the
47+
benchmark.
48+
"""
49+
50+
if required is not None:
51+
self.required = set(required)
52+
if additional_definitions is not None:
53+
self.additional_definitions = additional_definitions
54+

Pilot1/NT3/nt3_baseline_keras2.py

Lines changed: 21 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
from __future__ import print_function
2+
23
import pandas as pd
34
import numpy as np
45
import os
56
import sys
67
import gzip
7-
import argparse
8-
try:
9-
import configparser
10-
except ImportError:
11-
import ConfigParser as configparser
128

139
from keras import backend as K
1410

@@ -21,94 +17,22 @@
2117
from sklearn.metrics import accuracy_score
2218
from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler
2319

24-
TIMEOUT=3600 # in sec; set this to -1 for no timeout
25-
file_path = os.path.dirname(os.path.realpath(__file__))
26-
lib_path = os.path.abspath(os.path.join(file_path, '..', 'common'))
27-
sys.path.append(lib_path)
28-
lib_path2 = os.path.abspath(os.path.join(file_path, '..', '..', 'common'))
29-
sys.path.append(lib_path2)
30-
31-
import data_utils
32-
import p1_common, p1_common_keras
33-
from solr_keras import CandleRemoteMonitor, compute_trainable_params, TerminateOnTimeOut
34-
35-
36-
#url_nt3 = 'ftp://ftp.mcs.anl.gov/pub/candle/public/benchmarks/Pilot1/normal-tumor/'
37-
#file_train = 'nt_train2.csv'
38-
#file_test = 'nt_test2.csv'
39-
40-
#EPOCH = 400
41-
#BATCH = 20
42-
#CLASSES = 2
43-
44-
#PL = 60484 # 1 + 60483 these are the width of the RNAseq datasets
45-
#P = 60483 # 60483
46-
#DR = 0.1 # Dropout rate
47-
48-
def common_parser(parser):
49-
50-
parser.add_argument("--config_file", dest='config_file', type=str,
51-
default=os.path.join(file_path, 'nt3_default_model.txt'),
52-
help="specify model configuration file")
53-
54-
# Parse has been split between arguments that are common with the default neon parser
55-
# and all the other options
56-
parser = p1_common.get_default_neon_parse(parser)
57-
parser = p1_common.get_p1_common_parser(parser)
58-
59-
return parser
60-
61-
def get_nt3_parser():
62-
63-
parser = argparse.ArgumentParser(prog='nt3_baseline', formatter_class=argparse.ArgumentDefaultsHelpFormatter,
64-
description='Train Autoencoder - Pilot 1 Benchmark NT3')
65-
66-
return common_parser(parser)
67-
68-
def read_config_file(file):
69-
config = configparser.ConfigParser()
70-
config.read(file)
71-
section = config.sections()
72-
fileParams = {}
73-
74-
fileParams['data_url'] = eval(config.get(section[0],'data_url'))
75-
fileParams['train_data'] = eval(config.get(section[0],'train_data'))
76-
fileParams['test_data'] = eval(config.get(section[0],'test_data'))
77-
fileParams['model_name'] = eval(config.get(section[0],'model_name'))
78-
fileParams['conv'] = eval(config.get(section[0],'conv'))
79-
fileParams['dense'] = eval(config.get(section[0],'dense'))
80-
fileParams['activation'] = eval(config.get(section[0],'activation'))
81-
fileParams['out_act'] = eval(config.get(section[0],'out_act'))
82-
fileParams['loss'] = eval(config.get(section[0],'loss'))
83-
fileParams['optimizer'] = eval(config.get(section[0],'optimizer'))
84-
fileParams['metrics'] = eval(config.get(section[0],'metrics'))
85-
fileParams['epochs'] = eval(config.get(section[0],'epochs'))
86-
fileParams['batch_size'] = eval(config.get(section[0],'batch_size'))
87-
fileParams['learning_rate'] = eval(config.get(section[0], 'learning_rate'))
88-
fileParams['drop'] = eval(config.get(section[0],'drop'))
89-
fileParams['classes'] = eval(config.get(section[0],'classes'))
90-
fileParams['pool'] = eval(config.get(section[0],'pool'))
91-
fileParams['save'] = eval(config.get(section[0], 'save'))
92-
93-
# parse the remaining values
94-
for k,v in config.items(section[0]):
95-
if not k in fileParams:
96-
fileParams[k] = eval(v)
97-
98-
return fileParams
20+
#TIMEOUT=3600 # in sec; set this to -1 for no timeout
21+
22+
import nt3 as bmk
23+
import candle_keras as candle
9924

10025
def initialize_parameters():
101-
# Get command-line parameters
102-
parser = get_nt3_parser()
103-
args = parser.parse_args()
104-
#print('Args:', args)
105-
# Get parameters from configuration file
106-
fileParameters = read_config_file(args.config_file)
107-
#print ('Params:', fileParameters)
108-
# Consolidate parameter set. Command-line parameters overwrite file configuration
109-
gParameters = p1_common.args_overwrite_config(args, fileParameters)
110-
return gParameters
11126

27+
# Build benchmark object
28+
nt3Bmk = bmk.BenchmarkNT3(bmk.file_path, 'nt3_default_model.txt', 'keras',
29+
prog='nt3_baseline', desc='Multi-task (DNN) for data extraction from clinical reports - Pilot 3 Benchmark 1')
30+
31+
# Initialize parameters
32+
gParameters = candle.initialize_parameters(nt3Bmk)
33+
#benchmark.logger.info('Params: {}'.format(gParameters))
34+
35+
return gParameters
11236

11337
def load_data(train_path, test_path, gParameters):
11438

@@ -155,8 +79,8 @@ def run(gParameters):
15579
file_test = gParameters['test_data']
15680
url = gParameters['data_url']
15781

158-
train_file = data_utils.get_file(file_train, url+file_train, cache_subdir='Pilot1')
159-
test_file = data_utils.get_file(file_test, url+file_test, cache_subdir='Pilot1')
82+
train_file = candle.get_file(file_train, url+file_train, cache_subdir='Pilot1')
83+
test_file = candle.get_file(file_test, url+file_test, cache_subdir='Pilot1')
16084

16185
X_train, Y_train, X_test, Y_test = load_data(train_file, test_file, gParameters)
16286

@@ -231,10 +155,10 @@ def run(gParameters):
231155
#model.add(Dense(CLASSES))
232156
#model.add(Activation('softmax'))
233157

234-
kerasDefaults = p1_common.keras_default_config()
158+
kerasDefaults = candle.keras_default_config()
235159

236160
# Define optimizer
237-
optimizer = p1_common_keras.build_optimizer(gParameters['optimizer'],
161+
optimizer = candle.build_optimizer(gParameters['optimizer'],
238162
gParameters['learning_rate'],
239163
kerasDefaults)
240164

@@ -249,16 +173,16 @@ def run(gParameters):
249173
os.makedirs(output_dir)
250174

251175
# calculate trainable and non-trainable params
252-
gParameters.update(compute_trainable_params(model))
176+
gParameters.update(candle.compute_trainable_params(model))
253177

254178
# set up a bunch of callbacks to do work during model training..
255179
model_name = gParameters['model_name']
256180
path = '{}/{}.autosave.model.h5'.format(output_dir, model_name)
257181
# checkpointer = ModelCheckpoint(filepath=path, verbose=1, save_weights_only=False, save_best_only=True)
258182
csv_logger = CSVLogger('{}/training.log'.format(output_dir))
259183
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=1, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0)
260-
candleRemoteMonitor = CandleRemoteMonitor(params=gParameters)
261-
timeoutMonitor = TerminateOnTimeOut(TIMEOUT)
184+
candleRemoteMonitor = candle.CandleRemoteMonitor(params=gParameters)
185+
timeoutMonitor = candle.TerminateOnTimeOut(gParameters['timeout'])
262186
history = model.fit(X_train, Y_train,
263187
batch_size=gParameters['batch_size'],
264188
epochs=gParameters['epochs'],

0 commit comments

Comments (0)