
Commit 56e7064

adapted the tutorial for the new model
1 parent ffdd9e0

File tree

2 files changed: +46 -30 lines

docs/notebooks/segger_tutorial.ipynb

Lines changed: 11 additions & 11 deletions
@@ -283,27 +283,27 @@
     "\n",
     "dm.setup()\n",
     "\n",
-    "is_token_based = True\n",
     "num_tx_tokens = 500\n",
     "\n",
     "# If you use custom gene embeddings, use the following two lines instead:\n",
-    "# is_token_based = False\n",
     "# num_tx_tokens = dm.train[0].x_dict[\"tx\"].shape[1] # Set the number of tokens to the number of genes\n",
     "\n",
     "\n",
-    "num_bd_features = dm.train[0].x_dict[\"bd\"].shape[1]\n",
-    "\n",
-    "# Initialize the Lightning model\n",
-    "ls = LitSegger(\n",
-    "    is_token_based = is_token_based,\n",
-    "    num_node_features = {\"tx\": num_tx_tokens, \"bd\": num_bd_features},\n",
-    "    init_emb=8, \n",
+    "model = Segger(\n",
+    "    # is_token_based=is_token_based,\n",
+    "    num_tx_tokens=num_tx_tokens,\n",
+    "    init_emb=8,\n",
     "    hidden_channels=64,\n",
     "    out_channels=16,\n",
     "    heads=4,\n",
-    "    num_mid_layers=1,\n",
-    "    aggr='sum',\n",
+    "    num_mid_layers=3,\n",
     ")\n",
+    "model = to_hetero(model, ([\"tx\", \"bd\"], [(\"tx\", \"belongs\", \"bd\"), (\"tx\", \"neighbors\", \"tx\")]), aggr=\"sum\")\n",
+    "\n",
+    "batch = dm.train[0]\n",
+    "model.forward(batch.x_dict, batch.edge_index_dict)\n",
+    "# Wrap the model in LitSegger\n",
+    "ls = LitSegger(model=model)\n",
     "\n",
     "# Initialize the Lightning trainer\n",
     "trainer = Trainer(\n",

scripts/train_model_sample.py

Lines changed: 35 additions & 19 deletions
@@ -1,22 +1,26 @@
 from segger.data.io import XeniumSample
 from segger.training.train import LitSegger
 from segger.training.segger_data_module import SeggerDataModule
-from segger.prediction.predict import predict, load_model
+# from segger.prediction.predict import predict, load_model
+from segger.models.segger_model import Segger
+from segger.training.train import LitSegger
+from torch_geometric.nn import to_hetero
 from lightning.pytorch.loggers import CSVLogger
-from pytorch_lightning import Trainer
+from lightning import Trainer
 from pathlib import Path
 from lightning.pytorch.plugins.environments import LightningEnvironment
 from matplotlib import pyplot as plt
 import seaborn as sns
-
 # import pandas as pd
 from segger.data.utils import calculate_gene_celltype_abundance_embedding
-import scanpy as sc
+# import scanpy as sc
 import os
+from lightning import LightningModule
+
 
 
-segger_data_dir = segger_data_dir = Path("data_tidy/pyg_datasets/bc_rep1_emb_final_200")
-models_dir = Path("./models/bc_rep1_emb_final_200")
+segger_data_dir = segger_data_dir = Path("data_tidy/pyg_datasets/cosmx_pancreas")
+models_dir = Path("./models/cosmx_pancreas")
 
 # Base directory to store Pytorch Lightning models
 # models_dir = Path('models')
@@ -35,29 +39,41 @@
 
 # If you use custom gene embeddings, use the following two lines instead:
 is_token_based = False
-num_tx_tokens = (
-    dm.train[0].x_dict["tx"].shape[1]
-)  # Set the number of tokens to the number of genes
+# num_tx_tokens = (
+#     dm.train[0].x_dict["tx"].shape[1]
+# )  # Set the number of tokens to the number of genes
 
 
-num_bd_features = dm.train[0].x_dict["bd"].shape[1]
-
-# Initialize the Lightning model
-ls = LitSegger(
-    is_token_based=is_token_based,
-    num_node_features={"tx": num_tx_tokens, "bd": num_bd_features},
+model = Segger(
+    # is_token_based=is_token_based,
+    num_tx_tokens=25000,
     init_emb=8,
     hidden_channels=64,
     out_channels=16,
     heads=4,
     num_mid_layers=3,
-    aggr="sum",
-    learning_rate=1e-3,
 )
+model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum")
+
+batch = dm.train[0]
+model.forward(batch.x_dict, batch.edge_index_dict)
+# Wrap the model in LitSegger
+ls = LitSegger(model=model)
+
+# # Initialize the Lightning model
+# ls = LitSegger(
+#     # is_token_based=is_token_based,
+#     num_tx_tokens=7000,
+#     init_emb=8,
+#     hidden_channels=64,
+#     out_channels=16,
+#     heads=4,
+#     num_mid_layers=3,
+# )
 
 # Initialize the Lightning trainer
 trainer = Trainer(
-    accelerator="cuda",
+    accelerator="cpu",
     strategy="auto",
     precision="16-mixed",
     devices=2,  # set higher number if more gpus are available
@@ -67,4 +83,4 @@
 )
 
 
-trainer.fit(model=ls, datamodule=dm)
+trainer.fit(ls, datamodule=dm)
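Both files now build the model the same way: a homogeneous Segger module is converted with torch_geometric.nn.to_hetero, which essentially clones the module's layers per node and edge type of the (tx, bd) graph, and a single throwaway model.forward(batch.x_dict, batch.edge_index_dict) call runs before training, presumably so that lazily-initialized parameters get concrete shapes before Lightning configures the optimizer. A minimal self-contained sketch of that pattern follows; the Net module and the toy tensors are illustrative assumptions, not code from the segger repository.

import torch
from torch_geometric.nn import SAGEConv, to_hetero


class Net(torch.nn.Module):
    # Toy stand-in for the homogeneous Segger model (assumed, for illustration).
    def __init__(self):
        super().__init__()
        # (-1, -1) defers input-size inference to the first forward pass.
        self.conv = SAGEConv((-1, -1), 16)

    def forward(self, x, edge_index):
        return self.conv(x, edge_index)


# Same metadata as in the diff: two node types, two edge types.
metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")])
model = to_hetero(Net(), metadata, aggr="sum")

# Toy heterogeneous inputs standing in for batch.x_dict / batch.edge_index_dict.
x_dict = {"tx": torch.randn(10, 8), "bd": torch.randn(4, 8)}
edge_index_dict = {
    ("tx", "belongs", "bd"): torch.tensor([[0, 1], [0, 1]]),
    ("tx", "neighbors", "tx"): torch.tensor([[0, 1], [1, 2]]),
}

# The dummy forward pass: after this call every lazy parameter has a concrete
# shape, so the module can safely be wrapped and trained.
out = model(x_dict, edge_index_dict)
print({k: v.shape for k, v in out.items()})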
