import os
import time

import anndata as ad
import numpy as np
import scvi
import torch
from scvi.external import cytovi

# from sklearn.cluster import KMeans
# from threadpoolctl import threadpool_limits
6 | 11 |
|
## VIASH START
# Fallback parameter values for running this script standalone; when built as a
# Viash component, everything between the VIASH START/END markers is replaced
# with the configured parameters — presumably; confirm against the Viash config.
par = {
    # path to the input AnnData (.h5ad) file
    "input": "resources_test/task_cyto_batch_integration/mouse_spleen_flow_cytometry_subset/censored_split2.h5ad",
    # path where the batch-corrected output .h5ad will be written
    "output": "resources_test/task_cyto_batch_integration/mouse_spleen_flow_cytometry_subset/output_cytovi_split2.h5ad",
    "n_hidden": 128,  # hidden units per layer of the CytoVI model
    "n_layers": 1,  # number of hidden layers of the CytoVI model
    "max_epochs": 1000,  # upper bound on training epochs
    "train_size": 0.9,  # fraction of cells used for training (rest held out)
}
meta = {"name": "cytovi"}  # component metadata
## VIASH END
18 | 23 |
|
# Allow TF32 for matmul on CUDA to speed up training (slightly lower
# float32 precision; a no-op on non-CUDA runs).
torch.backends.cuda.matmul.allow_tf32 = True

# Use many workers for data loading, but never more than the CPUs actually
# available: a hard-coded 95 oversubscribes smaller machines and degrades
# throughput instead of improving it.
scvi.settings.num_workers = min(95, os.cpu_count() or 1)
| 29 | + |
print("Reading and preparing input files", flush=True)
adata = ad.read_h5ad(par["input"])

# Cast batch and sample identifiers to plain strings; these string columns are
# the categorical keys passed to CYTOVI.setup_anndata below.
adata.obs["batch_str"] = adata.obs["batch"].astype(str)
adata.obs["sample_key_str"] = adata.obs["sample"].astype(str)

# Split markers by the boolean `to_correct` flag in .var: only the flagged
# markers go through batch correction, the rest are carried along unchanged.
markers_to_correct = adata.var[adata.var["to_correct"]].index.to_numpy()
markers_not_correct = adata.var[~adata.var["to_correct"]].index.to_numpy()
|
33 | 45 | adata=adata_to_correct, |
34 | 46 | transformed_layer_key="preprocessed", |
35 | 47 | batch_key="batch_str", |
| 48 | + scaled_layer_key="scaled", |
36 | 49 | inplace=True, |
37 | 50 | ) |
38 | 51 |
|
print(
    f"Train CytoVI on {adata_to_correct.shape[0]} cells",
    flush=True,
)

# Register the AnnData with CytoVI: the "scaled" layer (produced by the
# preprocessing step above) is the model input, with batch and sample as
# categorical covariates.
cytovi.CYTOVI.setup_anndata(
    adata_to_correct,
    layer="scaled",
    batch_key="batch_str",
    sample_key="sample_key_str",
)

model = cytovi.CYTOVI(
    adata_to_correct, n_hidden=par["n_hidden"], n_layers=par["n_layers"]
)

print("Start training CytoVI model", flush=True)

# time.perf_counter() is the correct clock for measuring elapsed time:
# unlike time.time() it is monotonic and unaffected by system clock changes.
start = time.perf_counter()
model.train(
    batch_size=8192,  # large batch size to speed up training
    max_epochs=par["max_epochs"],
    train_size=par["train_size"],
)
end = time.perf_counter()
print(f"Training took {end - start:.2f} seconds", flush=True)
71 | 78 |
|
# get batch corrected data from the trained model
# (the correction and output-writing steps follow after this point)
print("Correcting data", flush=True)
0 commit comments