Skip to content

Commit df6ad1d

Browse files
rcannoodKaiWaldrant
andcommitted
add simple_mlp
Co-authored-by: Kai Waldrant <[email protected]>
1 parent abb0d96 commit df6ad1d

File tree

9 files changed

+362
-1
lines changed

9 files changed

+362
-1
lines changed

_viash.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ authors:
6868
info:
6969
github: rcannood
7070
orcid: "0000-0003-3641-729X"
71+
- name: Xueer Chen
72+
roles: [ contributor ]
73+
info:
74+
github: xuerchen
75+
76+
- name: Jiwei Liu
77+
roles: [ contributor ]
78+
info:
79+
github: daxiongshu
80+
81+
orcid: "0000-0002-8799-9763"
7182

7283
links:
7384
issue_tracker: https://github.com/openproblems-bio/task_predict_modality/issues
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Prediction component of the Simple MLP method.
__merge__: ../../../api/comp_method_predict.yaml
name: simplemlp_predict
resources:
  - type: python_script
    path: script.py
  # Folder that script.py adds to sys.path (models, utils) and reads the
  # per-task `yaml/mlp_<task>.yaml` files from.
  - path: ../resources/
engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    # run_args: ["--gpus all --ipc=host"]
    setup:
      - type: python
        pypi:
          - scikit-learn
          - scanpy
          - pytorch-lightning
# `executable` and `nextflow` are runners, not engines; declaring them under a
# second `engines:` key would duplicate the key above.
runners:
  - type: executable
  - type: nextflow
    directives:
      label: [highmem, hightime, midcpu, gpu, highsharedmem]
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from glob import glob
2+
import sys
3+
import numpy as np
4+
from scipy.sparse import csc_matrix
5+
import anndata as ad
6+
import torch
7+
from torch.utils.data import TensorDataset,DataLoader
8+
9+
## VIASH START
10+
par = {
11+
'input_train_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod1.h5ad',
12+
'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod2.h5ad',
13+
'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/test_mod1.h5ad',
14+
'input_model': 'output/model',
15+
'output': 'output/prediction'
16+
}
17+
# Placeholder `meta` for running the script outside viash.
meta = {
    # Parent of the `resources/` folder that is appended to sys.path below.
    'resources_dir': 'src/methods/simple_mlp',
    'cpus': 10,
    # Read when filling `uns['method_id']` of the written prediction.
    'functionality_name': 'simplemlp_predict',
}
21+
## VIASH END
22+
23+
resources_dir = f"{meta['resources_dir']}/resources"
24+
sys.path.append(resources_dir)
25+
from models import MLP
26+
import utils
27+
28+
def _predict(model,dl):
29+
model = model.cuda()
30+
model.eval()
31+
yps = []
32+
for x in dl:
33+
with torch.no_grad():
34+
yp = model(x[0].cuda())
35+
yps.append(yp.detach().cpu().numpy())
36+
yp = np.vstack(yps)
37+
return yp
38+
39+
40+
print('Load data', flush=True)
input_train_mod2 = ad.read_h5ad(par['input_train_mod2'])
input_test_mod1 = ad.read_h5ad(par['input_test_mod1'])

# determine variables
mod_1 = input_test_mod1.uns['modality']
mod_2 = input_train_mod2.uns['modality']

task = f'{mod_1}2{mod_2}'

# Predictions live in modality-2 feature space: one row per test cell, one
# column per feature of the training target (not of the modality-1 input).
n_cells = input_test_mod1.n_obs
n_targets = input_train_mod2.n_vars

print('Load ymean', flush=True)
ymean_path = f"{par['input_model']}/{task}_ymean.npy"
ymean = np.load(ymean_path)

print('Start predict', flush=True)
if task == 'GEX2ATAC':
    # The train script stores only ymean for this task and exits before
    # fitting any network, so the mean is predicted for every test cell.
    y_pred = ymean * np.ones([n_cells, n_targets])
else:
    folds = [0, 1, 2]

    ymean = torch.from_numpy(ymean).float()
    yaml_path = f"{resources_dir}/yaml/mlp_{task}.yaml"
    config = utils.load_yaml(yaml_path)
    X = input_test_mod1.layers["normalized"].toarray()
    X = torch.from_numpy(X).float()

    # The test loader does not depend on the fold, so build it once.
    te_ds = TensorDataset(X)
    te_loader = DataLoader(
        te_ds,
        batch_size=config.batch_size,
        num_workers=0,
        shuffle=False,
        drop_last=False,
    )

    yp = 0
    for fold in folds:
        load_path = f"{par['input_model']}/{task}_fold_{fold}/**.ckpt"
        print(load_path)
        ckpts = sorted(glob(load_path))
        if not ckpts:
            raise FileNotFoundError(
                f"No checkpoint found for fold {fold} matching '{load_path}'"
            )
        model_inf = MLP.load_from_checkpoint(
            ckpts[0],
            in_dim=X.shape[1],
            out_dim=n_targets,
            ymean=ymean,
            config=config,
        )
        yp = yp + _predict(model_inf, te_loader)

    # Ensemble by averaging the per-fold predictions.
    y_pred = yp / len(folds)

y_pred = csc_matrix(y_pred)

adata = ad.AnnData(
    layers={"normalized": y_pred},
    shape=y_pred.shape,
    uns={
        'dataset_id': input_test_mod1.uns['dataset_id'],
        'method_id': meta['functionality_name'],
    },
)

print('Write data', flush=True)
adata.write_h5ad(par['output'], compression="gzip")
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
__merge__: ../../../api/comp_method_train.yaml
2+
# Must match the `methods/simple_mlp` dependency and the `simple_mlp` entry
# used by the run_benchmark workflow.
name: simple_mlp
3+
label: Simple MLP
4+
summary: Ensemble of MLPs trained on different sites (team AXX)
5+
description: |
6+
This folder contains the AXX solution to the OpenProblems-NeurIPS2021 Single-Cell Multimodal Data Integration competition.
7+
The team took 4th place in the modality prediction task in terms of overall ranking across 4 subtasks, namely GEX
8+
to ADT, ADT to GEX, GEX to ATAC and ATAC to GEX. Specifically, our methods ranked 3rd in GEX to ATAC and 4th
9+
in GEX to ADT. More details about the task can be found in the
10+
[competition webpage](https://openproblems.bio/events/2021-09_neurips/documentation/about_tasks/task1_modality_prediction).
11+
references:
12+
doi: 10.1101/2022.04.11.487796
13+
links:
14+
documentation: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/AXX
15+
repository: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/AXX
16+
info:
17+
preferred_normalization: log_cp10k
18+
competition_submission_id: 170812
19+
resources:
20+
- path: main.nf
21+
type: nextflow_script
22+
entrypoint: run_wf
23+
dependencies:
24+
- name: predict_modality/methods/simplemlp_train
25+
- name: predict_modality/methods/simplemlp_predict
26+
runners:
27+
- type: nextflow

src/methods/simple_mlp/run/main.nf

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Wrapper workflow: runs the Simple MLP train component followed by the
// predict component and emits only the prediction.
workflow run_wf {
  take: input_ch
  main:
  output_ch = input_ch

    // Train on the two training modalities; the component's `output` is kept
    // in the state under the key `input_model`.
    | simplemlp_train.run(
      fromState: ["input_train_mod1", "input_train_mod2"],
      toState: ["input_model": "output"]
    )

    // Predict with the trained model.
    // NOTE(review): `input_transform` is not produced by the train step above;
    // confirm that it is expected in the incoming state.
    | simplemlp_predict.run(
      fromState: ["input_train_mod2", "input_test_mod1", "input_model", "input_transform"],
      toState: ["output": "output"]
    )

    // Keep only the `output` entry of the state.
    | map { tup ->
      [tup[0], [output: tup[1].output]]
    }

  emit: output_ch
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Training component of the Simple MLP method.
__merge__: ../../../api/comp_method_train.yaml
name: simplemlp_train
resources:
  - type: python_script
    path: script.py
  # Folder that script.py adds to sys.path (models, utils) and reads the
  # per-task `yaml/mlp_<task>.yaml` files from.
  - path: ../resources/
engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    # run_args: ["--gpus all --ipc=host"]
    setup:
      - type: python
        pypi:
          - scikit-learn
          - scanpy
          - pytorch-lightning
runners:
  - type: executable
  - type: nextflow
    directives:
      label: [highmem, hightime, midcpu, gpu, highsharedmem]
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import os
2+
import math
3+
import logging
4+
from pathlib import Path
5+
6+
import anndata as ad
7+
import numpy as np
8+
9+
import torch
10+
import pytorch_lightning as pl
11+
from torch.utils.data import TensorDataset, DataLoader
12+
from pytorch_lightning.callbacks import ModelCheckpoint
13+
from pytorch_lightning.loggers import TensorBoardLogger,WandbLogger
14+
15+
logging.basicConfig(level=logging.INFO)
16+
17+
## VIASH START
18+
par = {
19+
'input_train_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod1.h5ad',
20+
'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod2.h5ad',
21+
'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/test_mod1.h5ad',
22+
'output': 'output/model'
23+
}
24+
# Placeholder `meta` for running the script outside viash.
meta = {
    # Parent of the `resources/` folder that is appended to sys.path below.
    'resources_dir': 'src/methods/simple_mlp',
    'cpus': 10,
}
28+
## VIASH END
29+
30+
resources_dir = f"{meta['resources_dir']}/resources"
31+
32+
import sys
33+
sys.path.append(resources_dir)
34+
from models import MLP
35+
import utils
36+
37+
def _train(X, y, Xt, yt, logger, config, num_workers):
    """Fit one MLP on a training split and score it on the validation split.

    Args:
        X, y: numpy arrays with the training inputs and targets.
        Xt, yt: numpy arrays with the validation inputs and targets.
        logger: Lightning logger; its ``save_dir`` is also used as the
            checkpoint directory and the trainer's root directory.
        config: object exposing ``batch_size`` and ``epochs``; it is also
            passed on to ``MLP``.
        num_workers: number of worker processes for both data loaders.

    Returns:
        Tuple ``(score, predictions)``: the validation RMSE of the best
        checkpoint and the validation predictions as a numpy array.
    """
    def _as_float_tensor(array):
        return torch.from_numpy(array).float()

    X, y = _as_float_tensor(X), _as_float_tensor(y)
    # Column-wise mean of the training targets, handed to the model.
    ymean = torch.mean(y, dim=0, keepdim=True)
    Xt, yt = _as_float_tensor(Xt), _as_float_tensor(yt)

    tr_loader = DataLoader(
        TensorDataset(X, y),
        batch_size=config.batch_size,
        num_workers=num_workers,
        shuffle=True,
        drop_last=True,
    )
    te_loader = DataLoader(
        TensorDataset(Xt, yt),
        batch_size=config.batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False,
    )

    # Keep only the single best checkpoint according to 'valid_RMSE'.
    best_checkpoint = ModelCheckpoint(
        monitor='valid_RMSE',
        dirpath=logger.save_dir,
        save_top_k=1,
    )
    trainer = pl.Trainer(
        devices="auto",
        enable_checkpointing=True,
        logger=logger,
        max_epochs=config.epochs,
        callbacks=[best_checkpoint],
        default_root_dir=logger.save_dir,
    )

    net = MLP(X.shape[1], y.shape[1], ymean, config)
    trainer.fit(net, tr_loader, te_loader)

    # Re-predict the validation split with the best checkpoint.
    batch_preds = trainer.predict(net, te_loader, ckpt_path='best')
    val_pred = torch.cat(batch_preds, dim=0)

    score = ((val_pred - yt) ** 2).mean() ** 0.5
    print(f"VALID RMSE {score:.3f}")
    del trainer
    return score, val_pred.detach().numpy()
89+
90+
91+
92+
input_train_mod1 = ad.read_h5ad(par['input_train_mod1'])
input_train_mod2 = ad.read_h5ad(par['input_train_mod2'])

mod_1 = input_train_mod1.uns["modality"]
mod_2 = input_train_mod2.uns["modality"]

task = f'{mod_1}2{mod_2}'
yaml_path = f'{resources_dir}/yaml/mlp_{task}.yaml'

# NOTE(review): the results are not used further down; the calls are kept in
# case utils.to_site_donor validates or annotates the input - confirm.
obs_info = utils.to_site_donor(input_train_mod1)
# TODO: if we want this method to work for other datasets, resolve dependence on site notation
sites = obs_info.site.unique()

os.makedirs(par['output'], exist_ok=True)

# The target mean is stored for every task; the predict script loads it.
print('Compute ymean', flush=True)
ymean = np.asarray(input_train_mod2.layers["normalized"].mean(axis=0))
path = f"{par['output']}/{task}_ymean.npy"
np.save(path, ymean)

if task == "GEX2ATAC":
    logging.info(f"No training required for this task ({task}).")
    sys.exit(0)

if not os.path.exists(yaml_path):
    logging.error(f"No configuration file found for task '{task}'")
    sys.exit(1)

# Fold-independent settings are resolved once, outside the loop.
num_workers = meta.get("cpus") or 0
base_config = utils.load_yaml(yaml_path)

scores = []
# TODO: if we want this method to work for other datasets, dont use hardcoded range
for fold in range(3):
    run_name = f"{task}_fold_{fold}"
    save_path = f"{par['output']}/{run_name}"

    Path(save_path).mkdir(parents=True, exist_ok=True)

    X, y, Xt, yt = utils.split(input_train_mod1, input_train_mod2, fold)

    logger = TensorBoardLogger(save_path, name='')

    # Shrink the batch size when the fold has fewer rows than one batch;
    # `_replace` returns a copy, so the shared base config stays untouched.
    config = base_config
    if config.batch_size > X.shape[0]:
        config = config._replace(batch_size=math.ceil(X.shape[0] / 2))

    score, _ = _train(X, y, Xt, yt, logger, config, num_workers)
    # Store a plain float so the average below does not depend on
    # tensor-to-array conversion.
    scores.append(float(score))
    print(f"{task} Fold {fold} RMSE {score:.3f}")

score = np.mean(scores)
print('Overall', f'{score:.3f}')

src/workflows/run_benchmark/config.vsh.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ dependencies:
7373
- name: methods/lm
7474
- name: methods/lmds_irlba_rf
7575
- name: methods/guanlab_dengkw_pm
76+
- name: methods/simple_mlp
7677
- name: metrics/correlation
7778
- name: metrics/mse
7879
runners:

src/workflows/run_benchmark/main.nf

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ workflow run_wf {
2121
knnr_r,
2222
lm,
2323
lmds_irlba_rf,
24-
guanlab_dengkw_pm
24+
guanlab_dengkw_pm,
25+
simple_mlp
2526
]
2627

2728
// construct list of metrics

0 commit comments

Comments
 (0)