Commit 4c82f17

Update p3b3.py
1 parent 8741bb9 commit 4c82f17

File tree

1 file changed: +74 -102 lines changed


Pilot3/P3B3/p3b3.py

Lines changed: 74 additions & 102 deletions
@@ -1,113 +1,85 @@
-from __future__ import absolute_import
 from __future__ import print_function
+
+import numpy as np
+
+from sklearn.metrics import accuracy_score
+
 import os
 import sys
 import argparse
-try:
-    import configparser
-except ImportError:
-    import ConfigParser as configparser
 
 file_path = os.path.dirname(os.path.realpath(__file__))
-lib_path = os.path.abspath(os.path.join(file_path, '..', 'common'))
-sys.path.append(lib_path)
 lib_path2 = os.path.abspath(os.path.join(file_path, '..', '..', 'common'))
 sys.path.append(lib_path2)
 
-import p3_common
-
-def common_parser(parser):
-
-    parser.add_argument("--config_file", dest='config_file', type=str,
-                        default=os.path.join(file_path, 'p3b3_default_model.txt'),
-                        help="specify model configuration file")
-
-    # Parse has been split between arguments that are common with the default neon parser
-    # and all the other options
-    parser = p3_common.get_default_neon_parse(parser)
-    parser = p3_common.get_p3_common_parser(parser)
-
-    # Arguments that are applicable just to p3b1
-    parser = p3b3_parser(parser)
-
-    return parser
-
-def p3b3_parser(parser):
-    ### Hyperparameters and model save path
-
-    # these are leftover from other models but don't conflict so leave for now
-    parser.add_argument("--train", action="store_true",dest="train_bool",default=True,help="Invoke training")
-    parser.add_argument("--evaluate", action="store_true",dest="eval_bool",default=False,help="Use model for inference")
-    parser.add_argument("--home-dir",help="Home Directory",dest="home_dir",type=str,default='.')
-    parser.add_argument("--save-dir",help="Save Directory",dest="save_path",type=str,default=None)
-    parser.add_argument("--config-file",help="Config File",dest="config_file",type=str,default=os.path.join(file_path, 'p3b3_default_model.txt'))
-    parser.add_argument("--memo",help="Memo",dest="base_memo",type=str,default=None)
-    parser.add_argument("--seed", action="store_true",dest="seed",default=False,help="Random Seed")
-    parser.add_argument("--case",help="[Full, Center, CenterZ]",dest="case",type=str,default='CenterZ')
-    parser.add_argument("--fig", action="store_true",dest="fig_bool",default=False,help="Generate Prediction Figure")
-
-    # MTL_run params start here
-    parser.add_argument("-v", "--verbose", action="store_true",
-                        default= True,
-                        help="increase output verbosity")
-
-    parser.add_argument("--dropout", action="store",
-                        default=argparse.SUPPRESS, # DROPOUT, type=float,
-                        help="ratio of dropout used in fully connected layers")
-    parser.add_argument("--learning_rate", action='store',
-                        default=argparse.SUPPRESS, # LEARNING_RATE, type=float,
-                        help='learning rate')
-
-    parser.add_argument("--train_features", action="store",
-                        default='data/train_X.npy',
-                        help='training feature data filenames')
-    parser.add_argument("--train_truths", action="store",
-                        default='data/train_Y.npy',
-                        help='training truth data filenames')
-
-    parser.add_argument("--valid_features", action="store",
-                        default='data/test_X.npy',
-                        help='validation feature data filenames')
-    parser.add_argument("--valid_truths", action="store",
-                        default='data/test_Y.npy',
-                        help='validation truth data filenames')
-
-    parser.add_argument("--output_files", action="store",
-                        default='result.csv',
-                        help="output filename")
-
-    # parser.add_argument("--shared_nnet_spec", action="store",
-    #                     default=argparse.SUPPRESS, # DEF_SHARED_NNET_SPEC,
-    #                     help='network structure of shared layer')
-    # parser.add_argument("--individual_nnet_spec", action="store",
-    #                     default=argparse.SUPPRESS, # DEF_INDIV_NNET_SPEC,
-    #                     help='network structore of task-specific layer')
-
-    return parser
-
-
-def read_config_file(File):
-    config=configparser.ConfigParser()
-    config.read(File)
-    section=config.sections()
-    Global_Params={}
-
-    Global_Params['learning_rate'] =eval(config.get(section[0],'learning_rate'))
-    Global_Params['batch_size'] =eval(config.get(section[0],'batch_size'))
-    Global_Params['epochs'] =eval(config.get(section[0],'epochs'))
-    Global_Params['dropout'] =eval(config.get(section[0],'dropout'))
-
-    Global_Params['optimizer'] =eval(config.get(section[0],'optimizer'))
-
-    Global_Params['wv_len'] =eval(config.get(section[0],'wv_len'))
-    Global_Params['filter_sizes'] =eval(config.get(section[0],'filter_sizes'))
-    Global_Params['filter_sets'] =eval(config.get(section[0],'filter_sets'))
-    Global_Params['num_filters'] =eval(config.get(section[0],'num_filters'))
-
-    Global_Params['emb_l2'] =eval(config.get(section[0],'emb_l2'))
-    Global_Params['w_l2'] =eval(config.get(section[0],'w_l2'))
+import candle_keras as candle
+
+'''
+additional_definitions = [
+    {'name':'train_features',
+     'action':'store',
+     'default':'data/task0_0_train_feature.csv;data/task1_0_train_feature.csv;data/task2_0_train_feature.csv',
+     'help':'training feature data filenames'},
+    {'name':'train_truths',
+     'action':'store',
+     'default':'data/task0_0_train_label.csv;data/task1_0_train_label.csv;data/task2_0_train_label.csv',
+     'help':'training truth data filenames'},
+    {'name':'valid_features',
+     'action':'store',
+     'default':'data/task0_0_test_feature.csv;data/task1_0_test_feature.csv;data/task2_0_test_feature.csv',
+     'help':'validation feature data filenames'},
+    {'name':'valid_truths',
+     'action':'store',
+     'default':'data/task0_0_test_label.csv;data/task1_0_test_label.csv;data/task2_0_test_label.csv',
+     'help':'validation truth data filenames'},
+    {'name':'output_files',
+     'action':'store',
+     'default':'result0_0.csv;result1_0.csv;result2_0.csv',
+     'help':'output filename'},
+    {'name':'shared_nnet_spec',
+     'nargs':'+',
+     'type': int,
+     'help':'network structure of shared layer'},
+    {'name':'ind_nnet_spec',
+     'action':'list-of-lists',
+     'help':'network structure of task-specific layer'},
+    {'name':'case',
+     'default':'CenterZ',
+     'choices':['Full', 'Center', 'CenterZ'],
+     'help':'case classes'},
+    {'name':'fig',
+     'type': candle.str2bool,
+     'default': False,
+     'help':'Generate Prediction Figure'},
+    {'name':'feature_names',
+     'nargs':'+',
+     'type': str},
+    {'name':'n_fold',
+     'action':'store',
+     'type':int}
+]
+'''
+
+
+required = [
+    'learning_rate', 'batch_size', 'epochs', 'dropout', \
+    'optimizer', 'wv_len', \
+    'filter_sizes', 'filter_sets', 'num_filters', 'emb_l2', 'w_l2']
+
+
+
+class BenchmarkP3B3(candle.Benchmark):
+
+    def set_locals(self):
+        """Functionality to set variables specific for the benchmark
+        - required: set of required parameters for the benchmark.
+        - additional_definitions: list of dictionaries describing the additional parameters for the
+          benchmark.
+        """
+
+        if required is not None:
+            self.required = set(required)
+        # if additional_definitions is not None:
+        #     self.additional_definitions = additional_definitions
 
 
-    # note 'cool' is a boolean
-    #Global_Params['cool'] =config.get(section[0],'cool')
-    return Global_Params
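
A minimal usage sketch (not part of this commit) of how the new BenchmarkP3B3 class is typically consumed by a CANDLE baseline driver: the driver constructs the benchmark object and hands it to the common parameter-initialization helper, which merges p3b3_default_model.txt with any command-line overrides. The helper name candle.initialize_parameters and the prog/desc strings below are assumptions drawn from the pattern used by other Pilot baselines, not code from this diff.

    import p3b3 as bmk
    import candle_keras as candle

    def initialize_parameters():
        # Build the benchmark object defined in p3b3.py; file_path and the
        # default model file come from p3b3.py, while prog/desc are
        # illustrative placeholders.
        p3b3_bench = bmk.BenchmarkP3B3(bmk.file_path, 'p3b3_default_model.txt', 'keras',
                                       prog='p3b3_baseline',
                                       desc='Multi-task CNN, Pilot 3 Benchmark 3')

        # Let the CANDLE common code parse the default-model file plus any
        # command-line overrides into a flat parameter dictionary; the exact
        # helper name here is an assumption.
        gParameters = candle.initialize_parameters(p3b3_bench)
        return gParameters

    params = initialize_parameters()
    # The keys listed in `required` (learning_rate, batch_size, epochs, ...)
    # are then available in the returned dictionary.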
