|
9 | 9 | import os |
10 | 10 |
|
11 | 11 | ## VIASH START |
12 | | -# Note: this section is auto-generated by viash at runtime. To edit it, make changes |
13 | | -# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. |
14 | 12 | par = { |
15 | | - 'input': 'resources_test/.../input.h5ad', |
16 | | - 'output': 'output.h5ad', |
17 | | - "model": "large", |
18 | | -} |
19 | | -meta = { |
20 | | - 'name': 'scprint' |
| 13 | + "input": "resources_test/.../input.h5ad", |
| 14 | + "output": "output.h5ad", |
| 15 | + "model": "large", |
21 | 16 | } |
| 17 | +meta = {"name": "scprint"} |
22 | 18 | ## VIASH END |
23 | 19 |
|
24 | 20 | sys.path.append(meta["resources_dir"]) |
|
33 | 29 | elif input.uns["dataset_organism"] == "mus_musculus": |
34 | 30 | input.obs["organism_ontology_term_id"] = "NCBITaxon:10090" |
35 | 31 | else: |
36 | | - raise ValueError(f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'") |
| 32 | + raise ValueError( |
| 33 | + f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'" |
| 34 | + ) |
37 | 35 | adata = input.copy() |
38 | 36 |
|
39 | | -print('\n>>> Preprocessing data...', flush=True) |
| 37 | +print("\n>>> Preprocessing data...", flush=True) |
40 | 38 | preprocessor = Preprocessor( |
41 | 39 | # Lower this threshold for test datasets |
42 | | - min_valid_genes_id = 1000 if input.n_vars < 2000 else 10000, |
| 40 | + min_valid_genes_id=1000 if input.n_vars < 2000 else 10000, |
43 | 41 | # Turn off cell filtering to return results for all cells |
44 | | - filter_cell_by_counts = False, |
45 | | - min_nnz_genes = False, |
| 42 | + filter_cell_by_counts=False, |
| 43 | + min_nnz_genes=False, |
46 | 44 | do_postp=False, |
47 | 45 | # Skip ontology checks |
48 | | - skip_validate=True |
| 46 | + skip_validate=True, |
49 | 47 | ) |
50 | 48 | adata = preprocessor(adata) |
51 | 49 |
|
52 | 50 | print(f"\n>>> Downloading '{par['model']}' model...", flush=True) |
53 | 51 | model_checkpoint_file = hf_hub_download( |
54 | | - repo_id="jkobject/scPRINT", |
55 | | - filename=f"{par['model']}.ckpt" |
| 52 | + repo_id="jkobject/scPRINT", filename=f"{par['model']}.ckpt" |
56 | 53 | ) |
57 | 54 | print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) |
58 | 55 | model = scPrint.load_from_checkpoint( |
59 | 56 | model_checkpoint_file, |
60 | | - transformer = "normal", # Don't use this for GPUs with flashattention |
61 | | - precpt_gene_emb = None |
| 57 | + transformer="normal", # Don't use this for GPUs with flashattention |
| 58 | + precpt_gene_emb=None, |
62 | 59 | ) |
63 | 60 |
|
64 | | -print('\n>>> Embedding data...', flush=True) |
| 61 | +print("\n>>> Embedding data...", flush=True) |
65 | 62 | if torch.cuda.is_available(): |
66 | 63 | print("CUDA is available, using GPU", flush=True) |
67 | 64 | precision = "16" |
|
77 | 74 | max_len=4000, |
78 | 75 | add_zero_genes=0, |
79 | 76 | num_workers=n_cores_available, |
80 | | - doclass = False, |
81 | | - doplot = False, |
82 | | - precision = precision, |
83 | | - dtype = dtype, |
| 77 | + doclass=False, |
| 78 | + doplot=False, |
| 79 | + precision=precision, |
| 80 | + dtype=dtype, |
84 | 81 | ) |
85 | 82 | embedded, _ = embedder(model, adata, cache=False) |
86 | 83 |
|
|
0 commit comments