
Commit 84ae0ef

add tutorial
1 parent 204e50c commit 84ae0ef

File tree: 2 files changed, +2205 -0 lines changed

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
import argparse
import gc
import os
import pprint
import random
import string
from pathlib import Path
from typing import get_args

import torch
from sklearn.random_projection import GaussianRandomProjection

import wandb
from dance import logger
from dance.datasets.singlemodality import CellTypeAnnotationDataset
from dance.modules.single_modality.cell_type_annotation.svm import SVM
from dance.pipeline import PipelinePlaner, get_step3_yaml, run_step3, save_summary_data
from dance.registry import register_preprocessor
from dance.transforms.base import BaseTransform
from dance.typing import LogLevel
from dance.utils import set_seed


@register_preprocessor("feature", "cell")  # NOTE: register any custom preprocessing function to be used for tuning
class GaussRandProjFeature(BaseTransform):
    """Custom preprocessing to extract cell features via Gaussian random projection."""

    _DISPLAY_ATTRS = ("n_components", "eps")

    def __init__(self, n_components: int = 400, eps: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.n_components = n_components
        self.eps = eps

    def __call__(self, data):
        feat = data.get_feature(return_type="numpy")
        grp = GaussianRandomProjection(n_components=self.n_components, eps=self.eps)

        self.logger.info(f"Start generating cell feature via Gaussian random projection (d={self.n_components}).")
        data.data.obsm[self.out] = grp.fit_transform(feat)

        return data

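# A minimal standalone usage sketch (illustrative only; it assumes `data` is a dance data
# object loaded as in the main block below and relies on the output key `out` provided by
# BaseTransform, which the transform itself uses above):
#
#     grp_feat = GaussRandProjFeature(n_components=400, eps=0.1)
#     data = grp_feat(data)                      # projected features written to .obsm
#     cell_feat = data.data.obsm[grp_feat.out]   # shape: (n_cells, 400)
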
if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--cache", action="store_true", help="Cache processed data.")
    parser.add_argument("--dense_dim", type=int, default=400, help="dim of PCA")
    parser.add_argument("--gpu", type=int, default=0, help="GPU id, set to -1 for CPU")
    parser.add_argument("--log_level", type=str, default="INFO", choices=get_args(LogLevel))
    parser.add_argument("--species", default="mouse")
    parser.add_argument("--test_dataset", nargs="+", default=[2695], type=int, help="List of test dataset ids.")
    parser.add_argument("--tissue", default="Brain")  # TODO: Add option for different tissue name for train/test
    parser.add_argument("--train_dataset", nargs="+", default=[753], type=int, help="List of training dataset ids.")
    parser.add_argument("--valid_dataset", nargs="+", default=None, type=int, help="List of validation dataset ids.")
    parser.add_argument("--tune_mode", default="pipeline_params", choices=["pipeline", "params", "pipeline_params"])
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--count", type=int, default=2)
    parser.add_argument("--sweep_id", type=str, default=None)
    parser.add_argument("--summary_file_path", default="results/pipeline/best_test_acc.csv", type=str)
    parser.add_argument("--root_path", default=str(Path(__file__).resolve().parent), type=str)
    args = parser.parse_args()
    logger.setLevel(args.log_level)
    logger.info(f"\n{pprint.pformat(vars(args))}")
    file_root_path = Path(
        args.root_path, "_".join([
            "-".join([str(num) for num in dataset])
            for dataset in [args.train_dataset, args.valid_dataset, args.test_dataset] if dataset is not None
        ])).resolve()
    logger.info(f"\nFiles are saved in {file_root_path}")
    pipeline_planer = PipelinePlaner.from_config_file(f"{file_root_path}/{args.tune_mode}_tuning_config.yaml")
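    # For example, with the defaults above (--train_dataset 753, --test_dataset 2695, no
    # validation set), file_root_path resolves to <root_path>/753_2695, so under the default
    # tune mode the planer looks for <root_path>/753_2695/pipeline_params_tuning_config.yaml.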
    os.environ["WANDB_AGENT_MAX_INITIAL_FAILURES"] = "2000"

    def evaluate_pipeline(tune_mode=args.tune_mode, pipeline_planer=pipeline_planer):
        wandb.init(settings=wandb.Settings(start_method='thread'))

        set_seed(args.seed)
        model = SVM(args, random_state=args.seed)

        # Load raw data
        data = CellTypeAnnotationDataset(train_dataset=args.train_dataset, test_dataset=args.test_dataset,
                                         valid_dataset=args.valid_dataset, species=args.species, tissue=args.tissue,
                                         data_dir="../temp_data").load_data()

        # Prepare preprocessing pipeline and apply it to data
        kwargs = {tune_mode: dict(wandb.config)}
        preprocessing_pipeline = pipeline_planer.generate(**kwargs)
        print(f"Pipeline config:\n{preprocessing_pipeline.to_yaml()}")
        preprocessing_pipeline(data)

        # Obtain training, validation, and testing data
        x_train, y_train = data.get_train_data()
        y_train_converted = y_train.argmax(1)  # convert one-hot representation into label index representation
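        # e.g. a one-hot row [0, 0, 1, 0] becomes class index 2 after argmax(1)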
        x_test, y_test = data.get_test_data()
        x_valid, y_valid = data.get_val_data()

        # Train and evaluate the model
        model.fit(x_train, y_train_converted)
        train_score = model.score(x_train, y_train)
        score = model.score(x_valid, y_valid)
        test_score = model.score(x_test, y_test)
        wandb.log({"train_acc": train_score, "acc": score, "test_acc": test_score})
        wandb.finish()
        gc.collect()
        torch.cuda.empty_cache()

    entity, project, sweep_id = pipeline_planer.wandb_sweep_agent(
        evaluate_pipeline, sweep_id=args.sweep_id, count=args.count)  # Score can be recorded for each epoch
    save_summary_data(entity, project, sweep_id, summary_file_path=args.summary_file_path, root_path=file_root_path)
    if args.tune_mode == "pipeline" or args.tune_mode == "pipeline_params":
        get_step3_yaml(result_load_path=f"{args.summary_file_path}", step2_pipeline_planer=pipeline_planer,
                       conf_load_path=f"{Path(args.root_path).resolve().parent}/step3_default_params.yaml",
                       root_path=file_root_path)
        if args.tune_mode == "pipeline_params":
            run_step3(file_root_path, evaluate_pipeline, tune_mode="params", step2_pipeline_planer=pipeline_planer)
"""To reproduce SVM benchmarks, please refer to command lines below:
116+
117+
Mouse Brain
118+
$ python main.py --tune_mode (pipeline/params/pipeline_params) --species mouse --tissue Brain --train_dataset 753 --test_dataset 2695 --valid_dataset 3285
119+
120+
Mouse Spleen
121+
$ python main.py --tune_mode (pipeline/params/pipeline_params) --species mouse --tissue Spleen --train_dataset 1970 --test_dataset 1759 --valid_dataset 1970
122+
123+
Mouse Kidney
124+
$ python main.py --tune_mode (pipeline/params/pipeline_params) --species mouse --tissue Kidney --train_dataset 4682 --test_dataset 203 --valid_dataset 4682
125+
126+
Human Brain
127+
$ python main.py --tune_mode (pipeline/params/pipeline_params) --species human --tissue Brain --train_dataset 328 --test_dataset 138 --valid_dataset 328
128+
129+
Human Spleen
130+
$ python main.py --species human --tissue Spleen --train_dataset 3043 3777 4029 4115 4362 4657 --test_dataset 1729 2125 2184 2724 2743 --valid_dataset 3043 3777 4029 4115 4362 4657 --count 240
131+
132+
133+
main.py --species human --tissue Spleen --train_dataset 3043 3777 4029 4115 4362 4657 --test_dataset 1729 2125 2184 2724 2743 --valid_dataset 3043 3777 4029 4115 4362 4657 --count 240 --sweep_id=p1iletlj
134+
135+
"""
