 from segger.data.io import XeniumSample
 from segger.training.train import LitSegger
 from segger.training.segger_data_module import SeggerDataModule
-from segger.prediction.predict import predict, load_model
+# from segger.prediction.predict import predict, load_model
+from segger.models.segger_model import Segger
+from torch_geometric.nn import to_hetero
 from lightning.pytorch.loggers import CSVLogger
-from pytorch_lightning import Trainer
+from lightning import Trainer
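+# (the unified `lightning` package re-exports Trainer; `pytorch_lightning` is the standalone import path)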
 from pathlib import Path
 from lightning.pytorch.plugins.environments import LightningEnvironment
 from matplotlib import pyplot as plt
 import seaborn as sns
-
 # import pandas as pd
 from segger.data.utils import calculate_gene_celltype_abundance_embedding
-import scanpy as sc
+# import scanpy as sc
 import os
+from lightning import LightningModule
+
 
 
-segger_data_dir = Path("data_tidy/pyg_datasets/bc_rep1_emb_final_200")
-models_dir = Path("./models/bc_rep1_emb_final_200")
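+# Point to the preprocessed Segger (PyG) dataset and the checkpoint/output directory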
+segger_data_dir = Path("data_tidy/pyg_datasets/cosmx_pancreas")
+models_dir = Path("./models/cosmx_pancreas")
 
 # Base directory to store PyTorch Lightning models
 # models_dir = Path('models')
@@ ... @@ (unchanged lines omitted)
 
 # If you use custom gene embeddings, use the following two lines instead:
 is_token_based = False
-num_tx_tokens = (
-    dm.train[0].x_dict["tx"].shape[1]
-)  # Set the number of tokens to the number of genes
+# num_tx_tokens = (
+#     dm.train[0].x_dict["tx"].shape[1]
+# )  # Set the number of tokens to the number of genes
 
 
-num_bd_features = dm.train[0].x_dict["bd"].shape[1]
-
-# Initialize the Lightning model
-ls = LitSegger(
-    is_token_based=is_token_based,
-    num_node_features={"tx": num_tx_tokens, "bd": num_bd_features},
+# Initialize the core Segger GNN
+model = Segger(
+    # is_token_based=is_token_based,
+    num_tx_tokens=25000,
     init_emb=8,
     hidden_channels=64,
     out_channels=16,
     heads=4,
     num_mid_layers=3,
-    aggr="sum",
-    learning_rate=1e-3,
 )
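+# Convert the homogeneous model into a heterogeneous one over the graph's
+# node types ("tx" transcripts, "bd" boundaries) and edge types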
+model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum")
+
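+# Dry-run one training batch so any lazily-initialized submodules (common
+# after to_hetero) materialize their parameters before training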
+batch = dm.train[0]
+model(batch.x_dict, batch.edge_index_dict)
+# Wrap the model in LitSegger
+ls = LitSegger(model=model)
+
+# # Initialize the Lightning model
+# ls = LitSegger(
+#     # is_token_based=is_token_based,
+#     num_tx_tokens=7000,
+#     init_emb=8,
+#     hidden_channels=64,
+#     out_channels=16,
+#     heads=4,
+#     num_mid_layers=3,
+# )
 
 # Initialize the Lightning trainer
 trainer = Trainer(
-    accelerator="cuda",
+    accelerator="cpu",
     strategy="auto",
-    precision="16-mixed",
+    precision="bf16-mixed",  # fp16 AMP is GPU-only; on CPU Lightning falls back to bf16
     devices=2,  # set higher number if more devices are available
@@ ... @@ (unchanged lines omitted)
 )
 
 
-trainer.fit(model=ls, datamodule=dm)
+trainer.fit(ls, datamodule=dm)
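+
+# A minimal sketch for inspecting training curves afterwards, assuming the
+# Trainer was given a CSVLogger writing under models_dir (default layout:
+# <save_dir>/lightning_logs/version_*/metrics.csv) and that a "train_loss"
+# column was logged:
+# import pandas as pd
+# latest = sorted((models_dir / "lightning_logs").glob("version_*"))[-1]
+# metrics = pd.read_csv(latest / "metrics.csv")
+# sns.lineplot(data=metrics, x="step", y="train_loss")
+# plt.show()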