Skip to content

Commit c3ce8df

Browse files
committed
style code
1 parent 4774c98 commit c3ce8df

File tree

1 file changed

+20
-23
lines changed

1 file changed

+20
-23
lines changed

src/methods/scprint/script.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,12 @@
99
import os
1010

1111
## VIASH START
12-
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
13-
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
1412
par = {
15-
'input': 'resources_test/.../input.h5ad',
16-
'output': 'output.h5ad',
17-
"model": "large",
18-
}
19-
meta = {
20-
'name': 'scprint'
13+
"input": "resources_test/.../input.h5ad",
14+
"output": "output.h5ad",
15+
"model": "large",
2116
}
17+
meta = {"name": "scprint"}
2218
## VIASH END
2319

2420
sys.path.append(meta["resources_dir"])
@@ -33,35 +29,36 @@
3329
elif input.uns["dataset_organism"] == "mus_musculus":
3430
input.obs["organism_ontology_term_id"] = "NCBITaxon:10090"
3531
else:
36-
raise ValueError(f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'")
32+
raise ValueError(
33+
f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
34+
)
3735
adata = input.copy()
3836

39-
print('\n>>> Preprocessing data...', flush=True)
37+
print("\n>>> Preprocessing data...", flush=True)
4038
preprocessor = Preprocessor(
4139
# Lower this threshold for test datasets
42-
min_valid_genes_id = 1000 if input.n_vars < 2000 else 10000,
40+
min_valid_genes_id=1000 if input.n_vars < 2000 else 10000,
4341
# Turn off cell filtering to return results for all cells
44-
filter_cell_by_counts = False,
45-
min_nnz_genes = False,
42+
filter_cell_by_counts=False,
43+
min_nnz_genes=False,
4644
do_postp=False,
4745
# Skip ontology checks
48-
skip_validate=True
46+
skip_validate=True,
4947
)
5048
adata = preprocessor(adata)
5149

5250
print(f"\n>>> Downloading '{par['model']}' model...", flush=True)
5351
model_checkpoint_file = hf_hub_download(
54-
repo_id="jkobject/scPRINT",
55-
filename=f"{par['model']}.ckpt"
52+
repo_id="jkobject/scPRINT", filename=f"{par['model']}.ckpt"
5653
)
5754
print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
5855
model = scPrint.load_from_checkpoint(
5956
model_checkpoint_file,
60-
transformer = "normal", # Don't use this for GPUs with flashattention
61-
precpt_gene_emb = None
57+
transformer="normal", # Don't use this for GPUs with flashattention
58+
precpt_gene_emb=None,
6259
)
6360

64-
print('\n>>> Embedding data...', flush=True)
61+
print("\n>>> Embedding data...", flush=True)
6562
if torch.cuda.is_available():
6663
print("CUDA is available, using GPU", flush=True)
6764
precision = "16"
@@ -77,10 +74,10 @@
7774
max_len=4000,
7875
add_zero_genes=0,
7976
num_workers=n_cores_available,
80-
doclass = False,
81-
doplot = False,
82-
precision = precision,
83-
dtype = dtype,
77+
doclass=False,
78+
doplot=False,
79+
precision=precision,
80+
dtype=dtype,
8481
)
8582
embedded, _ = embedder(model, adata, cache=False)
8683

0 commit comments

Comments
 (0)