Skip to content

Commit 40ef977

Browse files
committed
many debugs and cleanups
1 parent 9d7ad17 commit 40ef977

23 files changed

+1189
-876
lines changed

notebooks/genelocs.ipynb

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,9 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": null,
15+
"execution_count": 2,
1616
"metadata": {},
17-
"outputs": [
18-
{
19-
"name": "stdout",
20-
"output_type": "stream",
21-
"text": [
22-
"\u001b[92m→\u001b[0m connected lamindb: jkobject/scprint_v2\n"
23-
]
24-
},
25-
{
26-
"name": "stderr",
27-
"output_type": "stream",
28-
"text": [
29-
"/home/ml4ig1/Documents code/scPRINT/.venv/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n",
30-
" from pkg_resources import get_distribution, DistributionNotFound\n",
31-
"/home/ml4ig1/Documents code/simpler_flash/src/simpler_flash/layer_norm.py:1044: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n",
32-
" @custom_fwd\n",
33-
"/home/ml4ig1/Documents code/simpler_flash/src/simpler_flash/layer_norm.py:1107: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.\n",
34-
" @custom_bwd\n"
35-
]
36-
}
37-
],
17+
"outputs": [],
3818
"source": [
3919
"from scprint2.tokenizers import protein_embeddings_generator\n",
4020
"from scprint2.utils.get_seq import load_fasta_species\n",
@@ -1326,7 +1306,7 @@
13261306
"name": "python",
13271307
"nbconvert_exporter": "python",
13281308
"pygments_lexer": "ipython3",
1329-
"version": "3.11.4"
1309+
"version": "3.11.11"
13301310
}
13311311
},
13321312
"nbformat": 4,

notebooks/generate_gene_embeddings.ipynb

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,9 @@
1818
},
1919
{
2020
"cell_type": "code",
21-
"execution_count": null,
21+
"execution_count": 2,
2222
"metadata": {},
23-
"outputs": [
24-
{
25-
"name": "stdout",
26-
"output_type": "stream",
27-
"text": [
28-
"\u001b[92m→\u001b[0m connected lamindb: jkobject/scprint_v2\n"
29-
]
30-
},
31-
{
32-
"name": "stderr",
33-
"output_type": "stream",
34-
"text": [
35-
"/home/ml4ig1/Documents code/scPRINT/.venv/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n",
36-
" from pkg_resources import get_distribution, DistributionNotFound\n",
37-
"/home/ml4ig1/Documents code/simpler_flash/src/simpler_flash/layer_norm.py:1044: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n",
38-
" @custom_fwd\n",
39-
"/home/ml4ig1/Documents code/simpler_flash/src/simpler_flash/layer_norm.py:1107: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.\n",
40-
" @custom_bwd\n"
41-
]
42-
}
43-
],
23+
"outputs": [],
4424
"source": [
4525
"from scprint2.tokenizers import protein_embeddings_generator\n",
4626
"# from RNABERT import RNABERT\n",
@@ -8642,7 +8622,7 @@
86428622
"name": "python",
86438623
"nbconvert_exporter": "python",
86448624
"pygments_lexer": "ipython3",
8645-
"version": "3.11.4"
8625+
"version": "3.11.11"
86468626
}
86478627
},
86488628
"nbformat": 4,

notebooks/prepare_scprint2.ipynb

Lines changed: 65 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,12 @@
1313
},
1414
{
1515
"cell_type": "code",
16-
"execution_count": 2,
16+
"execution_count": 1,
1717
"id": "636ee56e",
1818
"metadata": {},
19-
"outputs": [
20-
{
21-
"name": "stdout",
22-
"output_type": "stream",
23-
"text": [
24-
"\u001b[92m→\u001b[0m connected lamindb: jkobject/scprint2\n"
25-
]
26-
}
27-
],
19+
"outputs": [],
2820
"source": [
29-
"from scdataloader.utils import _adding_scbasecamp_genes, populate_my_ontology\n",
21+
"from scdataloader.utils import _adding_scbasecamp_genes, populate_my_ontology, load_genes\n",
3022
"import os.path\n",
3123
"import urllib.request\n",
3224
"import torch\n",
@@ -35,6 +27,14 @@
3527
"%autoreload 2"
3628
]
3729
},
30+
{
31+
"cell_type": "markdown",
32+
"id": "30e835d1",
33+
"metadata": {},
34+
"source": [
35+
"# prepare ontologies (just once)\n"
36+
]
37+
},
3838
{
3939
"cell_type": "code",
4040
"execution_count": null,
@@ -56,28 +56,12 @@
5656
"_adding_scbasecamp_genes()"
5757
]
5858
},
59-
{
60-
"cell_type": "code",
61-
"execution_count": null,
62-
"id": "ef93bb4e",
63-
"metadata": {},
64-
"outputs": [],
65-
"source": [
66-
"LOC = \"../../models/\" # \"../../../\"\n",
67-
"ckpt_path = os.path.join(LOC, \"18hebyht-final-small.ckpt\")\n",
68-
"if not os.path.exists(ckpt_path):\n",
69-
" url = (\n",
70-
" \"https://huggingface.co/jkobject/scPRINT/resolve/main/18hebyht-final-small.ckpt\"\n",
71-
" )\n",
72-
" urllib.request.urlretrieve(url, ckpt_path)"
73-
]
74-
},
7559
{
7660
"cell_type": "markdown",
7761
"id": "7b80be7f",
7862
"metadata": {},
7963
"source": [
80-
"# OPTIONAL: solving potential issues\n",
64+
"# [Optional, only for some models]: solving potential issues\n",
8165
"\n",
8266
"some models have additional elements that are not used anymore and can cause an\n",
8367
"issue when loading them on specific version of pytorch lightning and scPRINT or\n",
@@ -142,13 +126,64 @@
142126
"torch.save(m, model_checkpoint_file)"
143127
]
144128
},
129+
{
130+
"cell_type": "markdown",
131+
"id": "3a9a9e15",
132+
"metadata": {},
133+
"source": [
134+
"# loading the model\n",
135+
"\n",
136+
"this is what you will need to do every time you load a model\n"
137+
]
138+
},
145139
{
146140
"cell_type": "code",
147141
"execution_count": null,
148-
"id": "3a9a9e15",
142+
"id": "fdded332",
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"LOC = \"../../models/\" # \"../../../\"\n",
147+
"ckpt_path = os.path.join(LOC, \"small-v2.ckpt\")\n",
148+
"if not os.path.exists(ckpt_path):\n",
149+
" url = \"https://huggingface.co/jkobject/scPRINT/resolve/main/small-v2.ckpt\"\n",
150+
" urllib.request.urlretrieve(url, ckpt_path)"
151+
]
152+
},
153+
{
154+
"cell_type": "code",
155+
"execution_count": null,
156+
"id": "78714b6a",
149157
"metadata": {},
150158
"outputs": [],
151-
"source": []
159+
"source": [
160+
"model = scPRINT2.load_from_checkpoint(\n",
161+
" ckpt_path,\n",
162+
" precpt_gene_emb=None,\n",
163+
" gene_pos_file=None,\n",
164+
")\n",
165+
"if not torch.cuda.is_available():\n",
166+
" model = model.to(torch.float32)\n",
167+
"\n",
168+
"model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
169+
]
170+
},
171+
{
172+
"cell_type": "code",
173+
"execution_count": null,
174+
"id": "1a22c4eb",
175+
"metadata": {},
176+
"outputs": [],
177+
"source": [
178+
"# in some cases the gene ontology has changed too much since I trained the model,\n",
179+
"# so I need to remove the genes that are not in the ontology anymore from the model\n",
180+
"missing = set(model.genes) - set(load_genes(model.organisms).index)\n",
181+
"if len(missing) > 0:\n",
182+
" print(\n",
183+
" \"Warning: some genes missmatch exist between model and ontology: solving...\",\n",
184+
" )\n",
185+
" model._rm_genes(missing)"
186+
]
152187
}
153188
],
154189
"metadata": {

notebooks/scPRINT-2-repro-notebooks/batch_corr_op ft.ipynb

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
},
2020
{
2121
"cell_type": "code",
22-
"execution_count": 1,
22+
"execution_count": null,
2323
"id": "3b0caefc",
2424
"metadata": {
2525
"execution": {
@@ -64,6 +64,7 @@
6464
"import scanpy as sc\n",
6565
"from scprint2 import scPRINT2\n",
6666
"from scdataloader import Preprocessor\n",
67+
"from scdataloader.utils import load_genes\n",
6768
"from scprint2.tasks import Embedder, FinetuneBatchClass\n",
6869
"from scprint2.tasks.cell_emb import compute_classification\n",
6970
"from scprint2.utils import zero_shot_annotation_with_refinement\n",
@@ -108,6 +109,14 @@
108109
"! uv pip list | grep scib #same version as OP"
109110
]
110111
},
112+
{
113+
"cell_type": "markdown",
114+
"id": "a9d5fef3",
115+
"metadata": {},
116+
"source": [
117+
"## prepare the data\n"
118+
]
119+
},
111120
{
112121
"cell_type": "code",
113122
"execution_count": 3,
@@ -1031,9 +1040,17 @@
10311040
"pd.DataFrame(res[\"cellxgene_census/dkd\"])"
10321041
]
10331042
},
1043+
{
1044+
"cell_type": "markdown",
1045+
"id": "a1c45fa8",
1046+
"metadata": {},
1047+
"source": [
1048+
"## loading the model\n"
1049+
]
1050+
},
10341051
{
10351052
"cell_type": "code",
1036-
"execution_count": 9,
1053+
"execution_count": null,
10371054
"id": "cd8ca41a",
10381055
"metadata": {
10391056
"execution": {
@@ -1061,12 +1078,25 @@
10611078
}
10621079
],
10631080
"source": [
1064-
"model_checkpoint_file = \"../models/18hebyht-final-small.ckpt\"\n",
1081+
"LOC2 = \"../../models/\" # \"../../../\"\n",
1082+
"ckpt_path = os.path.join(LOC2, \"small-v2.ckpt\")\n",
1083+
"if not os.path.exists(ckpt_path):\n",
1084+
" url = \"https://huggingface.co/jkobject/scPRINT/resolve/main/small-v2.ckpt\"\n",
1085+
" urllib.request.urlretrieve(url, ckpt_path)\n",
10651086
"\n",
10661087
"model = scPRINT2.load_from_checkpoint(\n",
1067-
" model_checkpoint_file, precpt_gene_emb=None, gene_pos_file=None\n",
1088+
" ckpt_path, precpt_gene_emb=None, gene_pos_file=None\n",
10681089
")\n",
1069-
"model = model.to(\"cuda\")"
1090+
"if not torch.cuda.is_available():\n",
1091+
" model = model.to(torch.float32)\n",
1092+
"\n",
1093+
"model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1094+
"missing = set(model.genes) - set(load_genes(model.organisms).index)\n",
1095+
"if len(missing) > 0:\n",
1096+
" print(\n",
1097+
" \"Warning: some genes missmatch exist between model and ontology: solving...\",\n",
1098+
" )\n",
1099+
" model._rm_genes(missing)"
10701100
]
10711101
},
10721102
{

0 commit comments

Comments
 (0)