Skip to content

Commit 7c4a1f2

Browse files
committed
add UnoMT implemented in pytorch
1 parent 76b2ab6 commit 7c4a1f2

31 files changed

+6554
-0
lines changed

Pilot1/SCRATCH/UnoMT/README.md

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# UnoMT in Pytorch
2+
Multi-tasking (drug response, cell line classification, etc.) Uno Implemented in PyTorch.
3+
4+
## Todos
5+
* More labels for the network like drug labels;
6+
* Dataloader hanging problem when num_workers set to more than 0;
7+
* Better pre-processing for drug descriptor integer features;
8+
* Network regularization with weight decay and/or dropout;
9+
* Hyper-parameter searching;
10+
11+
## Prerequisites
12+
```
13+
Python 3.6.4
14+
PyTorch 0.4.1
15+
SciPy 1.1.0
16+
pandas 0.23.4
17+
Scikit-Learn 0.19.1
18+
urllib3 1.23
19+
joblib 0.12.2
20+
```
21+
22+
23+
The default network structure is shown below:
24+
<img src="https://github.com/xduan7/UnoPytorch/blob/master/images/default_network.jpg" width="100%">
25+
26+
An example of the program output for training on NCI60 and valdiation on all other data sources is shown below:
27+
```
28+
python3.6 ./launcher.py
29+
Training Arguments:
30+
{
31+
"trn_src": "NCI60",
32+
"val_srcs": [
33+
"NCI60",
34+
"CTRP",
35+
"GDSC",
36+
"CCLE",
37+
"gCSI"
38+
],
39+
"grth_scaling": "none",
40+
"dscptr_scaling": "std",
41+
"rnaseq_scaling": "std",
42+
"dscptr_nan_threshold": 0.0,
43+
"qed_scaling": "none",
44+
"rnaseq_feature_usage": "source_scale",
45+
"drug_feature_usage": "both",
46+
"validation_ratio": 0.2,
47+
"disjoint_drugs": false,
48+
"disjoint_cells": true,
49+
"gene_layer_dim": 1024,
50+
"gene_latent_dim": 512,
51+
"gene_num_layers": 2,
52+
"drug_layer_dim": 4096,
53+
"drug_latent_dim": 2048,
54+
"drug_num_layers": 2,
55+
"autoencoder_init": true,
56+
"resp_layer_dim": 2048,
57+
"resp_num_layers_per_block": 2,
58+
"resp_num_blocks": 4,
59+
"resp_num_layers": 2,
60+
"resp_dropout": 0.0,
61+
"resp_activation": "none",
62+
"cl_clf_layer_dim": 256,
63+
"cl_clf_num_layers": 2,
64+
"drug_target_layer_dim": 512,
65+
"drug_target_num_layers": 2,
66+
"drug_qed_layer_dim": 512,
67+
"drug_qed_num_layers": 2,
68+
"drug_qed_activation": "sigmoid",
69+
"resp_loss_func": "mse",
70+
"resp_opt": "SGD",
71+
"resp_lr": 1e-05,
72+
"cl_clf_opt": "SGD",
73+
"cl_clf_lr": 0.01,
74+
"drug_target_opt": "SGD",
75+
"drug_target_lr": 0.01,
76+
"drug_qed_loss_func": "mse",
77+
"drug_qed_opt": "SGD",
78+
"drug_qed_lr": 0.01,
79+
"resp_val_start_epoch": 0,
80+
"early_stop_patience": 20,
81+
"lr_decay_factor": 0.98,
82+
"trn_batch_size": 32,
83+
"val_batch_size": 256,
84+
"max_num_batches": 1000,
85+
"max_num_epochs": 1000,
86+
"multi_gpu": false,
87+
"no_cuda": false,
88+
"rand_state": 0
89+
}
90+
RespNet(
91+
(_RespNet__gene_encoder): Sequential(
92+
(dense_0): Linear(in_features=942, out_features=1024, bias=True)
93+
(relu_0): ReLU()
94+
(dense_1): Linear(in_features=1024, out_features=1024, bias=True)
95+
(relu_1): ReLU()
96+
(dense_2): Linear(in_features=1024, out_features=512, bias=True)
97+
)
98+
(_RespNet__drug_encoder): Sequential(
99+
(dense_0): Linear(in_features=4688, out_features=4096, bias=True)
100+
(relu_0): ReLU()
101+
(dense_1): Linear(in_features=4096, out_features=4096, bias=True)
102+
(relu_1): ReLU()
103+
(dense_2): Linear(in_features=4096, out_features=2048, bias=True)
104+
)
105+
(_RespNet__resp_net): Sequential(
106+
(dense_0): Linear(in_features=2561, out_features=2048, bias=True)
107+
(activation_0): ReLU()
108+
(residual_block_0): ResBlock(
109+
(block): Sequential(
110+
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
111+
(res_relu_0): ReLU()
112+
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
113+
)
114+
(activation): ReLU()
115+
)
116+
(residual_block_1): ResBlock(
117+
(block): Sequential(
118+
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
119+
(res_relu_0): ReLU()
120+
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
121+
)
122+
(activation): ReLU()
123+
)
124+
(residual_block_2): ResBlock(
125+
(block): Sequential(
126+
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
127+
(res_relu_0): ReLU()
128+
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
129+
)
130+
(activation): ReLU()
131+
)
132+
(residual_block_3): ResBlock(
133+
(block): Sequential(
134+
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
135+
(res_relu_0): ReLU()
136+
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
137+
)
138+
(activation): ReLU()
139+
)
140+
(dense_1): Linear(in_features=2048, out_features=2048, bias=True)
141+
(res_relu_1): ReLU()
142+
(dense_2): Linear(in_features=2048, out_features=2048, bias=True)
143+
(res_relu_2): ReLU()
144+
(dense_out): Linear(in_features=2048, out_features=1, bias=True)
145+
)
146+
)
147+
================================================================================
148+
Training Epoch 1:
149+
Drug Weighted QED Regression Loss: 0.055694
150+
Drug Response Regression Loss: 1871.18
151+
152+
Validation Results:
153+
Cell Line Classification:
154+
Category Accuracy: 98.98%;
155+
Site Accuracy: 80.95%;
156+
Type Accuracy: 82.76%
157+
Drug Target Family Classification Accuracy: 1.85%
158+
Drug Weighted QED Regression
159+
MSE: 0.028476 MAE: 0.137004 R2: +0.17
160+
Drug Response Regression:
161+
NCI60 MSE: 1482.07 MAE: 27.89 R2: +0.53
162+
CTRP MSE: 2554.45 MAE: 38.62 R2: +0.27
163+
GDSC MSE: 2955.78 MAE: 42.73 R2: +0.11
164+
CCLE MSE: 2799.06 MAE: 42.44 R2: +0.31
165+
gCSI MSE: 2601.50 MAE: 38.44 R2: +0.35
166+
Epoch Running Time: 110.0 Seconds.
167+
================================================================================
168+
Training Epoch 2:
169+
...
170+
...
171+
172+
Program Running Time: 8349.6 Seconds.
173+
================================================================================
174+
Overall Validation Results:
175+
176+
Best Results from Different Models (Epochs):
177+
Cell Line Categories Best Accuracy: 99.474% (Epoch = 5)
178+
Cell Line Sites Best Accuracy: 97.401% (Epoch = 60)
179+
Cell Line Types Best Accuracy: 97.368% (Epoch = 40)
180+
Drug Target Family Best Accuracy: 66.667% (Epoch = 23)
181+
Drug Weighted QED Best R2 Score: +0.7422 (Epoch = 59, MSE = 0.008837, MAE = 0.069400)
182+
NCI60 Best R2 Score: +0.8107 (Epoch = 56, MSE = 601.18, MAE = 16.57)
183+
CTRP Best R2 Score: +0.3945 (Epoch = 37, MSE = 2127.28, MAE = 31.44)
184+
GDSC Best R2 Score: +0.2448 (Epoch = 22, MSE = 2506.03, MAE = 35.55)
185+
CCLE Best R2 Score: +0.4729 (Epoch = 4, MSE = 2153.30, MAE = 33.63)
186+
gCSI Best R2 Score: +0.4512 (Epoch = 31, MSE = 2203.04, MAE = 32.63)
187+
188+
Best Results from the Same Model (Epoch = 22):
189+
Cell Line Categories Accuracy: 99.408%
190+
Cell Line Sites Accuracy: 97.138%
191+
Cell Line Types Accuracy: 97.039%
192+
Drug Target Family Accuracy: 57.407%
193+
Drug Weighted QED R2 Score: +0.6033 (MSE = 0.013601, MAE = 0.093341)
194+
NCI60 R2 Score: +0.7885 (MSE = 672.00, MAE = 17.89)
195+
CTRP R2 Score: +0.3841 (MSE = 2163.66, MAE = 32.28)
196+
GDSC R2 Score: +0.2448 (MSE = 2506.03, MAE = 35.55)
197+
CCLE R2 Score: +0.4653 (MSE = 2184.62, MAE = 34.12)
198+
gCSI R2 Score: +0.4271 (MSE = 2299.59, MAE = 32.93)
199+
```
200+
201+
For default hyper parameters, the transfer learning matrix results are shown below:
202+
<p align="center">
203+
<img src="https://github.com/xduan7/UnoPytorch/blob/master/images/default_results.jpg" width="80%">
204+
</p>
205+
206+
Note that the green cells represents R2 score of higher than 0.1, red cells are R2 scores lower than -0.1 and yellows are for all the values in between.

0 commit comments

Comments
 (0)