
Commit b7edd10

add mxnet implementation

1 parent d54b7f4 commit b7edd10

File tree

1 file changed

+299 -0 lines changed

P1B3/p1b3_baseline_mxnet.py

Lines changed: 299 additions & 0 deletions

@@ -0,0 +1,299 @@
from __future__ import division, print_function

import argparse
import logging

import numpy as np
import pandas as pd

import mxnet as mx
from mxnet.io import DataBatch, DataIter

# For non-interactive plotting
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

import p1b3


# Model and Training parameters

# Seed for random generation
SEED = 2016
# Size of batch for training
BATCH_SIZE = 100
# Number of training epochs
NB_EPOCH = 20
# Number of data generator workers
NB_WORKER = 1

# Percentage of dropout used in training
DROP = 0.1
# Activation function (options supported by mx.sym.Activation: 'relu', 'tanh', 'sigmoid', 'softrelu')
ACTIVATION = 'relu'
LOSS = 'mse'
OPTIMIZER = 'sgd'

# Type of feature scaling (options: 'maxabs': to [-1,1],
#                                   'minmax': to [0,1],
#                                   'std': standard normalization,
#                                   None: no scaling)
SCALING = 'std'
# Number of features to (randomly) sample from cell lines or drug descriptors (0 = use all)
FEATURE_SUBSAMPLE = 500

# Number of units in fully connected (dense) layers
D1 = 1000
D2 = 500
D3 = 100
D4 = 50
DENSE_LAYERS = [D1, D2, D3, D4]

# Filter count, filter length, and stride per convolution layer
C1 = 10, 10, 5  # nb_filter, filter_length, stride
C2 = 0, 0, 0    # disabled layer
# CONVOLUTION_LAYERS = list(C1 + C2)
CONVOLUTION_LAYERS = [0, 0, 0]
POOL = 10

MIN_LOGCONC = -5.
MAX_LOGCONC = -4.

CATEGORY_CUTOFFS = [0.]

# Print arrays in full (threshold=np.nan is rejected by newer NumPy)
np.set_printoptions(threshold=np.inf)
np.random.seed(SEED)

def get_parser():
    parser = argparse.ArgumentParser(prog='p1b3_baseline',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="increase output verbosity")
    parser.add_argument("-a", "--activation", action="store",
                        default=ACTIVATION,
                        help="activation function to use in inner layers: relu, tanh, sigmoid...")
    parser.add_argument("-b", "--batch_size", action="store",
                        default=BATCH_SIZE, type=int,
                        help="batch size")
    parser.add_argument("-c", "--convolution", action="store", nargs='+', type=int,
                        default=CONVOLUTION_LAYERS,
                        help="integer array describing convolution layers: conv1_nb_filter, conv1_filter_len, conv1_stride, conv2_nb_filter, conv2_filter_len, conv2_stride ...")
    parser.add_argument("-d", "--dense", action="store", nargs='+', type=int,
                        default=DENSE_LAYERS,
                        help="number of units in fully connected layers in an integer array")
    parser.add_argument("-e", "--epochs", action="store",
                        default=NB_EPOCH, type=int,
                        help="number of training epochs")
    parser.add_argument("-l", "--locally_connected", action="store_true",
                        default=False,  # TODO: not currently supported
                        help="use locally connected layers instead of convolution layers")
    parser.add_argument("-o", "--optimizer", action="store",
                        default=OPTIMIZER,
                        help="optimizer to use: sgd, rmsprop, ...")
    parser.add_argument("--drop", action="store",
                        default=DROP, type=float,
                        help="ratio of dropout used in fully connected layers")
    parser.add_argument("--loss", action="store",
                        default=LOSS,
                        help="loss function to use: mse, ...")
    parser.add_argument("--pool", action="store",
                        default=POOL, type=int,
                        help="pooling layer length")
    parser.add_argument("--scaling", action="store",
                        default=SCALING,
                        help="type of feature scaling; 'maxabs': to [-1,1]; 'minmax': to [0,1]; 'std': standard unit normalization; None: no normalization")
    parser.add_argument("--drug_features", action="store",
                        default="descriptors",
                        help="use dragon7 descriptors, latent representations from Aspuru-Guzik's SMILES autoencoder, or both, or random features; 'descriptors', 'latent', 'both', 'noise'")
    parser.add_argument("--feature_subsample", action="store",
                        default=FEATURE_SUBSAMPLE, type=int,
                        help="number of features to randomly sample from each category (cell line expression, drug descriptors, etc.); 0 means use all features")
    parser.add_argument("--min_logconc", action="store",
                        default=MIN_LOGCONC, type=float,
                        help="min log concentration of dose response data to use: -3.0 to -7.0")
    parser.add_argument("--max_logconc", action="store",
                        default=MAX_LOGCONC, type=float,
                        help="max log concentration of dose response data to use: -3.0 to -7.0")
    parser.add_argument("--subsample", action="store",
                        default='naive_balancing',
                        help="dose response subsample strategy; None or 'naive_balancing'")
    parser.add_argument("--category_cutoffs", action="store", nargs='+', type=float,
                        default=CATEGORY_CUTOFFS,
                        help="list of growth cutoffs (between -1 and +1) separating non-response and response categories")
    parser.add_argument("--train_samples", action="store",
                        default=0, type=int,
                        help="overrides the number of training samples if set to nonzero")
    parser.add_argument("--val_samples", action="store",
                        default=0, type=int,
                        help="overrides the number of validation samples if set to nonzero")
    parser.add_argument("--save", action="store",
                        default='save',
                        help="prefix of output files")
    parser.add_argument("--scramble", action="store_true",
                        help="randomly shuffle dose response data")
    parser.add_argument("--workers", action="store",
                        default=NB_WORKER, type=int,
                        help="number of data generator workers")
    parser.add_argument("--gpus", action="store", nargs='*',
                        default=[], type=int,
                        help="set IDs of GPUs to use")

    return parser

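# Illustrative invocation (a sketch; the flag values are arbitrary examples):
#   python p1b3_baseline_mxnet.py -e 10 -b 50 --feature_subsample 0 --gpus 0 1
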
def extension_from_parameters(args):
    """Construct string for saving model with annotation of parameters"""
    ext = '.mx'
    ext += '.A={}'.format(args.activation)
    ext += '.B={}'.format(args.batch_size)
    ext += '.D={}'.format(args.drop)
    ext += '.E={}'.format(args.epochs)
    if args.feature_subsample:
        ext += '.F={}'.format(args.feature_subsample)
    if args.convolution:
        name = 'LC' if args.locally_connected else 'C'
        layer_list = list(range(0, len(args.convolution), 3))
        for l, i in enumerate(layer_list):
            nb_filter = args.convolution[i]
            filter_len = args.convolution[i+1]
            stride = args.convolution[i+2]
            if nb_filter <= 0 or filter_len <= 0 or stride <= 0:
                break
            ext += '.{}{}={},{},{}'.format(name, l+1, nb_filter, filter_len, stride)
        # test the conv parameters themselves, not the range list (whose first entry is always 0)
        if args.pool and args.convolution[0] and args.convolution[1]:
            ext += '.P={}'.format(args.pool)
    for i, n in enumerate(args.dense):
        if n:
            ext += '.D{}={}'.format(i+1, n)
    ext += '.S={}'.format(args.scaling)

    return ext

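# Worked example (derived from the defaults above): with activation 'relu',
# batch_size 100, drop 0.1, epochs 20, feature_subsample 500, convolution
# [0, 0, 0], dense [1000, 500, 100, 50] and scaling 'std', this yields:
#   .mx.A=relu.B=100.D=0.1.E=20.F=500.D1=1000.D2=500.D3=100.D4=50.S=std
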
class ConcatDataIter(DataIter):
    """Data iterator for concatenated features"""

    def __init__(self, data_loader,
                 partition='train',
                 batch_size=32,
                 num_data=None,
                 shape=None):
        super(ConcatDataIter, self).__init__()
        self.data = data_loader
        self.batch_size = batch_size
        self.gen = p1b3.DataGenerator(data_loader, partition=partition, batch_size=batch_size, shape=shape, concat=True)
        self.num_data = num_data or self.gen.num_data
        self.cursor = 0
        self.gen = self.gen.flow()

    @property
    def provide_data(self):
        return [('concat_features', (self.batch_size, self.data.input_dim))]

    @property
    def provide_label(self):
        return [('growth', (self.batch_size,))]

    def reset(self):
        self.cursor = 0

    def iter_next(self):
        self.cursor += self.batch_size
        return self.cursor <= self.num_data

    def next(self):
        if self.iter_next():
            x, y = next(self.gen)
            return DataBatch(data=[mx.nd.array(x)], label=[mx.nd.array(y)])
        else:
            raise StopIteration

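# Consumption sketch (assumes the p1b3.DataGenerator interface used above):
# the Module below calls reset() between epochs and next() once per batch;
# next() wraps each (x, y) pair from the generator in a DataBatch of NDArrays.
#   it = ConcatDataIter(loader, batch_size=100)
#   batch = it.next()  # one DataBatch: data=[concat_features], label=[growth]
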
def plot_network(net, filename):
    try:
        dot = mx.viz.plot_network(net)
    except ImportError:
        # mx.viz.plot_network requires the optional graphviz package
        return
    try:
        dot.render(filename, view=False)
        print('Plotted network architecture in {}'.format(filename+'.pdf'))
    except Exception:
        return

def main():
    parser = get_parser()
    args = parser.parse_args()
    print('Args:', args)

    logging_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=logging_level, format='')

    ext = extension_from_parameters(args)

    loader = p1b3.DataLoader(feature_subsample=args.feature_subsample,
                             scaling=args.scaling,
                             drug_features=args.drug_features,
                             scramble=args.scramble,
                             min_logconc=args.min_logconc,
                             max_logconc=args.max_logconc,
                             subsample=args.subsample,
                             category_cutoffs=args.category_cutoffs)

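    # Build the symbolic network. Each sample arrives as one flat concatenated
    # feature vector; the optional convolution path below reshapes it to
    # (batch, 1, input_dim, 1) so 1-D filters run as Nx1 2-D convolutions.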
    net = mx.sym.Variable('concat_features')
    out = mx.sym.Variable('growth')

    if args.convolution and args.convolution[0]:
        net = mx.sym.Reshape(data=net, shape=(args.batch_size, 1, loader.input_dim, 1))
        layer_list = list(range(0, len(args.convolution), 3))
        for l, i in enumerate(layer_list):
            nb_filter = args.convolution[i]
            filter_len = args.convolution[i+1]
            stride = args.convolution[i+2]
            if nb_filter <= 0 or filter_len <= 0 or stride <= 0:
                break
            net = mx.sym.Convolution(data=net, num_filter=nb_filter, kernel=(filter_len, 1), stride=(stride, 1))
            net = mx.sym.Activation(data=net, act_type=args.activation)
            if args.pool:
                net = mx.sym.Pooling(data=net, pool_type="max", kernel=(args.pool, 1), stride=(1, 1))

    for layer in args.dense:
        if layer:
            net = mx.sym.FullyConnected(data=net, num_hidden=layer)
            net = mx.sym.Activation(data=net, act_type=args.activation)
            if args.drop:
                net = mx.sym.Dropout(data=net, p=args.drop)
    net = mx.sym.FullyConnected(data=net, num_hidden=1)
    net = mx.sym.LinearRegressionOutput(data=net, label=out)
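    # LinearRegressionOutput attaches a squared-error (L2) loss between the
    # single-unit prediction and the 'growth' label, matching the 'mse' metric.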

    plot_network(net, 'net'+ext)

    train_iter = ConcatDataIter(loader, batch_size=args.batch_size, num_data=args.train_samples)
    val_iter = ConcatDataIter(loader, partition='val', batch_size=args.batch_size, num_data=args.val_samples)

    devices = mx.cpu()
    if args.gpus:
        devices = [mx.gpu(i) for i in args.gpus]

    mod = mx.mod.Module(net,
                        data_names=('concat_features',),
                        label_names=('growth',),
                        context=devices)

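    # fit() drives the training loop: one pass over train_iter per epoch with
    # evaluation on val_iter; Speedometer logs throughput every 20 batches.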
    mod.fit(train_iter, eval_data=val_iter,
            eval_metric=args.loss,
            optimizer=args.optimizer,
            num_epoch=args.epochs,
            batch_end_callback=mx.callback.Speedometer(args.batch_size, 20))


if __name__ == '__main__':
    main()
