Skip to content

Commit e5c1083

Browse files
authored
Merge pull request #44 from ECP-CANDLE/develop
Updated save / load weights for multi-gpu model
2 parents 925b790 + 870e8c3 commit e5c1083

File tree

2 files changed

+344
-1
lines changed

2 files changed

+344
-1
lines changed

Pilot1/Uno/uno_baseline_keras2.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,17 @@ def build_feature_model(input_shape, name='', dense_layers=[1000, 1000],
201201
model = Model(x_input, h, name=name)
202202
return model
203203

204-
class SimpleWeightSaver(Callback):
204+
class SimpleWeightSaver(Callback):
205+
205206
def __init__(self, fname):
206207
self.fname = fname
207208

209+
def set_model(self, model):
210+
if isinstance(model.layers[-2], Model):
211+
self.model = model.layers[-2]
212+
else:
213+
self.model = model
214+
208215
def on_train_end(self, logs={}):
209216
self.model.save_weights(self.fname)
210217

@@ -402,6 +409,7 @@ def warmup_scheduler(epoch):
402409
if len(args.gpus) > 1:
403410
from keras.utils import multi_gpu_model
404411
gpu_count = len(args.gpus)
412+
logger.info("Multi GPU with {} gpus".format(gpu_count))
405413
model = multi_gpu_model(template_model, cpu_merge=False, gpus=gpu_count)
406414
else:
407415
model = template_model
@@ -411,6 +419,7 @@ def warmup_scheduler(epoch):
411419
if args.learning_rate:
412420
K.set_value(optimizer.lr, args.learning_rate)
413421

422+
414423
model.compile(loss=args.loss, optimizer=optimizer, metrics=[mae, r2])
415424

416425
# calculate trainable and non-trainable params

Pilot1/UnoMT/README.md

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
# UnoMT in Pytorch
2+
Multi-tasking (drug response, cell line classification, etc.) Uno Implemented in PyTorch.
3+
https://github.com/xduan7/UnoPytorch
4+
5+
6+
## Todos
7+
* More labels for the network like drug labels;
8+
* Dataloader hanging problem when num_workers set to more than 0;
9+
* Better pre-processing for drug descriptor integer features;
10+
* Network regularization with weight decay and/or dropout;
11+
* Hyper-parameter searching;
12+
13+
## Prerequisites
14+
```
15+
Python 3.6.4
16+
PyTorch 0.4.1
17+
SciPy 1.1.0
18+
pandas 0.23.4
19+
Scikit-Learn 0.19.1
20+
urllib3 1.23
21+
joblib 0.12.2
22+
```
23+
24+
25+
The default network structure is shown below:
26+
<img src="https://github.com/xduan7/UnoPytorch/blob/master/images/default_network.jpg" width="100%">
27+
28+
An example of the program output for training on NCI60 and valdiation on all other data sources is shown below:
29+
```
30+
python unoMT_baseline_pytorch.py --resp_val_start_epoch 2 --epochs 5
31+
Importing candle utils for pytorch
32+
Created unoMT benchmark
33+
Configuration file: ./unoMT_default_model.txt
34+
{'autoencoder_init': True,
35+
'cl_clf_layer_dim': 256,
36+
'cl_clf_lr': 0.008,
37+
'cl_clf_num_layers': 2,
38+
'cl_clf_opt': 'SGD',
39+
'disjoint_cells': True,
40+
'disjoint_drugs': False,
41+
'drop': 0.1,
42+
'drug_feature_usage': 'both',
43+
'drug_latent_dim': 1024,
44+
'drug_layer_dim': 4096,
45+
'drug_num_layers': 2,
46+
'drug_qed_activation': 'sigmoid',
47+
'drug_qed_layer_dim': 1024,
48+
'drug_qed_loss_func': 'mse',
49+
'drug_qed_lr': 0.01,
50+
'drug_qed_num_layers': 2,
51+
'drug_qed_opt': 'SGD',
52+
'drug_target_layer_dim': 1024,
53+
'drug_target_lr': 0.002,
54+
'drug_target_num_layers': 2,
55+
'drug_target_opt': 'SGD',
56+
'dscptr_nan_threshold': 0.0,
57+
'dscptr_scaling': 'std',
58+
'early_stop_patience': 5,
59+
'epochs': 1000,
60+
'gene_latent_dim': 512,
61+
'gene_layer_dim': 1024,
62+
'gene_num_layers': 2,
63+
'grth_scaling': 'none',
64+
'l2_regularization': 1e-05,
65+
'lr_decay_factor': 0.98,
66+
'max_num_batches': 1000,
67+
'qed_scaling': 'none',
68+
'resp_activation': 'none',
69+
'resp_layer_dim': 2048,
70+
'resp_loss_func': 'mse',
71+
'resp_lr': 1e-05,
72+
'resp_num_blocks': 4,
73+
'resp_num_layers': 2,
74+
'resp_num_layers_per_block': 2,
75+
'resp_opt': 'SGD',
76+
'resp_val_start_epoch': 0,
77+
'rnaseq_feature_usage': 'combat',
78+
'rnaseq_scaling': 'std',
79+
'rng_seed': 0,
80+
'save_path': 'save/unoMT',
81+
'solr_root': '',
82+
'timeout': 3600,
83+
'train_sources': 'NCI60',
84+
'trn_batch_size': 32,
85+
'val_batch_size': 256,
86+
'val_sources': ['NCI60', 'CTRP', 'GDSC', 'CCLE', 'gCSI'],
87+
'val_split': 0.2}
88+
Params:
89+
{'autoencoder_init': True,
90+
'cl_clf_layer_dim': 256,
91+
'cl_clf_lr': 0.008,
92+
'cl_clf_num_layers': 2,
93+
'cl_clf_opt': 'SGD',
94+
'datatype': <class 'numpy.float32'>,
95+
'disjoint_cells': True,
96+
'disjoint_drugs': False,
97+
'drop': 0.1,
98+
'drug_feature_usage': 'both',
99+
'drug_latent_dim': 1024,
100+
'drug_layer_dim': 4096,
101+
'drug_num_layers': 2,
102+
'drug_qed_activation': 'sigmoid',
103+
'drug_qed_layer_dim': 1024,
104+
'drug_qed_loss_func': 'mse',
105+
'drug_qed_lr': 0.01,
106+
'drug_qed_num_layers': 2,
107+
'drug_qed_opt': 'SGD',
108+
'drug_target_layer_dim': 1024,
109+
'drug_target_lr': 0.002,
110+
'drug_target_num_layers': 2,
111+
'drug_target_opt': 'SGD',
112+
'dscptr_nan_threshold': 0.0,
113+
'dscptr_scaling': 'std',
114+
'early_stop_patience': 5,
115+
'epochs': 5,
116+
'experiment_id': 'EXP000',
117+
'gene_latent_dim': 512,
118+
'gene_layer_dim': 1024,
119+
'gene_num_layers': 2,
120+
'gpus': [],
121+
'grth_scaling': 'none',
122+
'l2_regularization': 1e-05,
123+
'logfile': None,
124+
'lr_decay_factor': 0.98,
125+
'max_num_batches': 1000,
126+
'multi_gpu': False,
127+
'no_cuda': False,
128+
'output_dir': '/home/jamal/Code/ECP/CANDLE/Benchmarks/Pilot1/UnoMT/Output/EXP000/RUN000',
129+
'qed_scaling': 'none',
130+
'resp_activation': 'none',
131+
'resp_layer_dim': 2048,
132+
'resp_loss_func': 'mse',
133+
'resp_lr': 1e-05,
134+
'resp_num_blocks': 4,
135+
'resp_num_layers': 2,
136+
'resp_num_layers_per_block': 2,
137+
'resp_opt': 'SGD',
138+
'resp_val_start_epoch': 2,
139+
'rnaseq_feature_usage': 'combat',
140+
'rnaseq_scaling': 'std',
141+
'rng_seed': 0,
142+
'run_id': 'RUN000',
143+
'save_path': 'save/unoMT',
144+
'shuffle': False,
145+
'solr_root': '',
146+
'timeout': 3600,
147+
'train_bool': True,
148+
'train_sources': 'NCI60',
149+
'trn_batch_size': 32,
150+
'val_batch_size': 256,
151+
'val_sources': ['NCI60', 'CTRP', 'GDSC', 'CCLE', 'gCSI'],
152+
'val_split': 0.2,
153+
'verbose': None}
154+
Parameters initialized
155+
Failed to split NCI60 cells in stratified way. Splitting randomly ...
156+
Failed to split NCI60 cells in stratified way. Splitting randomly ...
157+
Failed to split CCLE cells in stratified way. Splitting randomly ...
158+
Failed to split CCLE drugs stratified on growth and correlation. Splitting solely on avg growth ...
159+
Failed to split gCSI drugs stratified on growth and correlation. Splitting solely on avg growth ...
160+
RespNet(
161+
(_RespNet__gene_encoder): Sequential(
162+
(dense_0): Linear(in_features=942, out_features=1024, bias=True)
163+
(relu_0): ReLU()
164+
(dense_1): Linear(in_features=1024, out_features=1024, bias=True)
165+
(relu_1): ReLU()
166+
(dense_2): Linear(in_features=1024, out_features=512, bias=True)
167+
)
168+
(_RespNet__drug_encoder): Sequential(
169+
(dense_0): Linear(in_features=4688, out_features=4096, bias=True)
170+
(relu_0): ReLU()
171+
(dense_1): Linear(in_features=4096, out_features=4096, bias=True)
172+
(relu_1): ReLU()
173+
(dense_2): Linear(in_features=4096, out_features=1024, bias=True)
174+
)
175+
(_RespNet__resp_net): Sequential(
176+
(dense_0): Linear(in_features=1537, out_features=2048, bias=True)
177+
(activation_0): ReLU()
178+
(residual_block_0): ResBlock(
179+
(block): Sequential(
180+
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
181+
(res_dropout_0): Dropout(p=0.1)
182+
(res_relu_0): ReLU()
183+
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
184+
(res_dropout_1): Dropout(p=0.1)
185+
)
186+
(activation): ReLU()
187+
)
188+
(residual_block_1): ResBlock(
189+
(block): Sequential(
190+
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
191+
(res_dropout_0): Dropout(p=0.1)
192+
(res_relu_0): ReLU()
193+
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
194+
(res_dropout_1): Dropout(p=0.1)
195+
)
196+
(activation): ReLU()
197+
)
198+
(residual_block_2): ResBlock(
199+
(block): Sequential(
200+
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
201+
(res_dropout_0): Dropout(p=0.1)
202+
(res_relu_0): ReLU()
203+
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
204+
(res_dropout_1): Dropout(p=0.1)
205+
)
206+
(activation): ReLU()
207+
)
208+
(residual_block_3): ResBlock(
209+
(block): Sequential(
210+
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
211+
(res_dropout_0): Dropout(p=0.1)
212+
(res_relu_0): ReLU()
213+
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
214+
(res_dropout_1): Dropout(p=0.1)
215+
)
216+
(activation): ReLU()
217+
)
218+
(dense_1): Linear(in_features=2048, out_features=2048, bias=True)
219+
(dropout_1): Dropout(p=0.1)
220+
(res_relu_1): ReLU()
221+
(dense_2): Linear(in_features=2048, out_features=2048, bias=True)
222+
(dropout_2): Dropout(p=0.1)
223+
(res_relu_2): ReLU()
224+
(dense_out): Linear(in_features=2048, out_features=1, bias=True)
225+
)
226+
)
227+
Data sizes:
228+
Train:
229+
Data set: NCI60 Size: 882873
230+
Validation:
231+
Data set: NCI60 Size: 260286
232+
Data set: CTRP Size: 1040021
233+
Data set: GDSC Size: 235812
234+
Data set: CCLE Size: 17510
235+
Data set: gCSI Size: 10323
236+
================================================================================
237+
Training Epoch 1:
238+
Drug Weighted QED Regression Loss: 0.022274
239+
Drug Response Regression Loss: 1881.89
240+
Epoch Running Time: 13.2 Seconds.
241+
================================================================================
242+
Training Epoch 2:
243+
Drug Weighted QED Regression Loss: 0.019416
244+
Drug Response Regression Loss: 1348.13
245+
Epoch Running Time: 12.9 Seconds.
246+
================================================================================
247+
Training Epoch 3:
248+
Drug Weighted QED Regression Loss: 0.015868
249+
Drug Response Regression Loss: 1123.27
250+
Cell Line Classification:
251+
Category Accuracy: 99.01%;
252+
Site Accuracy: 94.11%;
253+
Type Accuracy: 94.18%
254+
Drug Target Family Classification Accuracy: 44.44%
255+
Drug Weighted QED Regression
256+
MSE: 0.018845 MAE: 0.111807 R2: +0.45
257+
Drug Response Regression:
258+
NCI60 MSE: 973.04 MAE: 22.18 R2: +0.69
259+
CTRP MSE: 2404.64 MAE: 34.04 R2: +0.32
260+
GDSC MSE: 2717.81 MAE: 36.53 R2: +0.19
261+
CCLE MSE: 2518.47 MAE: 36.60 R2: +0.38
262+
gCSI MSE: 2752.33 MAE: 36.97 R2: +0.35
263+
Epoch Running Time: 54.6 Seconds.
264+
================================================================================
265+
Training Epoch 4:
266+
Drug Weighted QED Regression Loss: 0.014096
267+
Drug Response Regression Loss: 933.27
268+
Cell Line Classification:
269+
Category Accuracy: 99.34%;
270+
Site Accuracy: 96.12%;
271+
Type Accuracy: 96.18%
272+
Drug Target Family Classification Accuracy: 44.44%
273+
Drug Weighted QED Regression
274+
MSE: 0.018467 MAE: 0.110287 R2: +0.46
275+
Drug Response Regression:
276+
NCI60 MSE: 844.51 MAE: 20.41 R2: +0.73
277+
CTRP MSE: 2314.19 MAE: 33.76 R2: +0.35
278+
GDSC MSE: 2747.73 MAE: 36.65 R2: +0.18
279+
CCLE MSE: 2482.03 MAE: 35.89 R2: +0.39
280+
gCSI MSE: 2665.35 MAE: 36.27 R2: +0.37
281+
Epoch Running Time: 54.9 Seconds.
282+
================================================================================
283+
Training Epoch 5:
284+
Drug Weighted QED Regression Loss: 0.013514
285+
Drug Response Regression Loss: 846.06
286+
Cell Line Classification:
287+
Category Accuracy: 99.38%;
288+
Site Accuracy: 95.89%;
289+
Type Accuracy: 95.30%
290+
Drug Target Family Classification Accuracy: 44.44%
291+
Drug Weighted QED Regression
292+
MSE: 0.017026 MAE: 0.106697 R2: +0.50
293+
Drug Response Regression:
294+
NCI60 MSE: 835.82 MAE: 21.33 R2: +0.74
295+
CTRP MSE: 2653.04 MAE: 37.98 R2: +0.25
296+
GDSC MSE: 2892.86 MAE: 39.76 R2: +0.13
297+
CCLE MSE: 2412.75 MAE: 36.82 R2: +0.41
298+
gCSI MSE: 2888.99 MAE: 38.70 R2: +0.32
299+
Epoch Running Time: 55.5 Seconds.
300+
Program Running Time: 191.1 Seconds.
301+
================================================================================
302+
Overall Validation Results:
303+
304+
Best Results from Different Models (Epochs):
305+
Cell Line Categories Best Accuracy: 99.375% (Epoch = 5)
306+
Cell Line Sites Best Accuracy: 96.118% (Epoch = 4)
307+
Cell Line Types Best Accuracy: 96.184% (Epoch = 4)
308+
Drug Target Family Best Accuracy: 44.444% (Epoch = 3)
309+
Drug Weighted QED Best R2 Score: +0.5034 (Epoch = 5, MSE = 0.017026, MAE = 0.106697)
310+
NCI60 Best R2 Score: +0.7369 (Epoch = 5, MSE = 835.82, MAE = 21.33)
311+
CTRP Best R2 Score: +0.3469 (Epoch = 4, MSE = 2314.19, MAE = 33.76)
312+
GDSC Best R2 Score: +0.1852 (Epoch = 3, MSE = 2717.81, MAE = 36.53)
313+
CCLE Best R2 Score: +0.4094 (Epoch = 5, MSE = 2412.75, MAE = 36.82)
314+
gCSI Best R2 Score: +0.3693 (Epoch = 4, MSE = 2665.35, MAE = 36.27)
315+
316+
Best Results from the Same Model (Epoch = 5):
317+
Cell Line Categories Accuracy: 99.375%
318+
Cell Line Sites Accuracy: 95.888%
319+
Cell Line Types Accuracy: 95.296%
320+
Drug Target Family Accuracy: 44.444%
321+
Drug Weighted QED R2 Score: +0.5034 (MSE = 0.017026, MAE = 0.106697)
322+
NCI60 R2 Score: +0.7369 (MSE = 835.82, MAE = 21.33)
323+
CTRP R2 Score: +0.2513 (MSE = 2653.04, MAE = 37.98)
324+
GDSC R2 Score: +0.1327 (MSE = 2892.86, MAE = 39.76)
325+
CCLE R2 Score: +0.4094 (MSE = 2412.75, MAE = 36.82)
326+
gCSI R2 Score: +0.3164 (MSE = 2888.99, MAE = 38.70)
327+
```
328+
329+
For default hyper parameters, the transfer learning matrix results are shown below:
330+
<p align="center">
331+
<img src="https://github.com/xduan7/UnoPytorch/blob/master/images/default_results.jpg" width="80%">
332+
</p>
333+
334+
Note that the green cells represents R2 score of higher than 0.1, red cells are R2 scores lower than -0.1 and yellows are for all the values in between.

0 commit comments

Comments
 (0)