Commit a4d0116

committed: fixes 5k convergence, see scripts for examples
1 parent 0374cd4 commit a4d0116

File tree: 10 files changed, +1275 -2 lines changed

.gitignore

Lines changed: 4 additions & 2 deletions
@@ -178,5 +178,7 @@ dev*
 *_old*
 .dev

-scripts/*
-.scripts/*
+# scripts/*
+.scripts/*
+
+models/*

Lines changed: 83 additions & 0 deletions
from pathlib import Path

import math
import numpy as np
import pandas as pd
import scanpy as sc

from segger.data.parquet.sample import STSampleParquet
from segger.data.utils import calculate_gene_celltype_abundance_embedding
from segger.data.parquet._utils import get_polygons_from_xy

"""
This script preprocesses Xenium spatial transcriptomics data for the SEGGER cell
segmentation model. Parameters are tuned for a 5K gene panel.

Key steps:
1. Data Loading:
   - Loads scRNA-seq reference data to create gene-celltype embeddings
   - Imports Xenium transcripts and nucleus boundaries

2. Parameter Optimization:
   - Calculates neighborhood parameters based on tissue characteristics
   - dist_tx: sets the transcript neighbor search radius to ~1/4 of the typical nucleus size
   - k_tx: determines the number of transcript neighbors to sample based on local density

3. Dataset Creation:
   - Filters transcripts to those overlapping nuclei
   - Creates graph connections between nearby transcripts
   - Splits data into training/validation sets
   - Saves in PyG format for SEGGER training

Usage:
- Input: raw Xenium data (transcripts.parquet, nucleus_boundaries.parquet)
- Output: processed dataset with graph structure and embeddings
"""

XENIUM_DATA_DIR = Path(
    "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real"
)
SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_seg_nuclei")
SCRNASEQ_FILE = Path("data_tidy/Human_CRC/scRNAseq.h5ad")
CELLTYPE_COLUMN = "Level1"  # change this to your annotation column name

scrnaseq = sc.read(SCRNASEQ_FILE)

# Subsample the scRNA-seq reference if needed
# sc.pp.subsample(scrnaseq, 0.1)
# scrnaseq.var_names_make_unique()

# Calculate gene-celltype abundance embeddings from the reference data
gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(
    scrnaseq,
    CELLTYPE_COLUMN,
)

# Initialize the spatial transcriptomics sample object
sample = STSampleParquet(
    base_dir=XENIUM_DATA_DIR,
    n_workers=4,
    sample_type="xenium",
    weights=gene_celltype_abundance_embedding,
)

# Build and save the PyG dataset
sample.save(
    data_dir=SEGGER_DATA_DIR,
    k_bd=3,                   # number of boundary points to connect
    dist_bd=10,               # maximum distance for boundary connections
    k_tx=5,                   # transcript neighbors per node
    dist_tx=5,                # transcript neighbor search radius
    tile_width=50,            # tile size for processing
    tile_height=50,
    neg_sampling_ratio=10.0,  # 10:1 negative:positive samples
    frac=1.0,                 # use all data
    val_prob=0.3,             # 30% validation set
    test_prob=0,              # no test set
)
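
The docstring above describes deriving dist_tx from the typical nucleus size and k_tx from the local transcript density, but the values passed to sample.save() are hard-coded. Below is a minimal sketch of that derivation, assuming the standard Xenium column names (cell_id, vertex_x, vertex_y in nucleus_boundaries.parquet; x_location, y_location in transcripts.parquet); suggest_tx_parameters is a hypothetical helper, not part of the segger API.

import math
from pathlib import Path

import pandas as pd


def suggest_tx_parameters(xenium_dir: Path) -> tuple[float, int]:
    """Heuristic dist_tx / k_tx suggestions from nucleus size and transcript density."""
    nuclei = pd.read_parquet(xenium_dir / "nucleus_boundaries.parquet")
    tx = pd.read_parquet(xenium_dir / "transcripts.parquet")

    # Typical nucleus size: median per-nucleus bounding-box extent
    extents = nuclei.groupby("cell_id").agg(
        dx=("vertex_x", lambda v: v.max() - v.min()),
        dy=("vertex_y", lambda v: v.max() - v.min()),
    )
    typical_nucleus_size = float(extents[["dx", "dy"]].mean(axis=1).median())

    # dist_tx: ~1/4 of the typical nucleus size (as stated in the docstring)
    dist_tx = typical_nucleus_size / 4

    # k_tx: expected number of transcripts within dist_tx, from the mean density
    area = (tx["x_location"].max() - tx["x_location"].min()) * (
        tx["y_location"].max() - tx["y_location"].min()
    )
    density = len(tx) / area  # transcripts per unit area
    k_tx = max(3, round(density * math.pi * dist_tx**2))

    return dist_tx, k_tx


# Hypothetical usage:
# dist_tx, k_tx = suggest_tx_parameters(XENIUM_DATA_DIR)
# print(f"suggested dist_tx={dist_tx:.1f}, k_tx={k_tx}")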

scripts/1_train_5k.py

Lines changed: 73 additions & 0 deletions
import os
from pathlib import Path

import seaborn as sns
from matplotlib import pyplot as plt

from lightning import Trainer, LightningModule
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.plugins.environments import LightningEnvironment
from torch_geometric.nn import to_hetero

from segger.data.utils import calculate_gene_celltype_abundance_embedding
from segger.models.segger_model import Segger
from segger.training.segger_data_module import SeggerDataModule
from segger.training.train import LitSegger
# from segger.prediction.predict import predict, load_model
# import pandas as pd
# import scanpy as sc


segger_data_dir = Path("data_tidy/pyg_datasets/human_CRC_seg_cells")
models_dir = Path("./models/human_CRC_seg_cells")

# Base directory to store PyTorch Lightning models
# models_dir = Path('models')

# Initialize the Lightning data module
dm = SeggerDataModule(
    data_dir=segger_data_dir,
    batch_size=2,
    num_workers=2,
)
dm.setup()

# Token-based embeddings:
# is_token_based = True
# num_tx_tokens = 500

# If you use custom gene embeddings, use the following two lines instead:
is_token_based = False
num_tx_tokens = dm.train[0].x_dict["tx"].shape[1]  # set the number of tokens to the number of genes

model = Segger(
    num_tx_tokens=num_tx_tokens,
    init_emb=8,
    hidden_channels=32,
    out_channels=16,
    heads=4,
    num_mid_layers=3,
)
model = to_hetero(
    model,
    (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]),
    aggr="sum",
)

# Run one forward pass on a sample batch (ensures any lazily initialized modules are built)
batch = dm.train[0]
model.forward(batch.x_dict, batch.edge_index_dict)

# Wrap the model in LitSegger
ls = LitSegger(model=model)

# Initialize the Lightning trainer
trainer = Trainer(
    accelerator="gpu",
    strategy="auto",
    precision="16-mixed",
    devices=4,  # increase if more GPUs are available
    max_epochs=150,
    default_root_dir=models_dir,
    logger=CSVLogger(models_dir),
)

trainer.fit(ls, datamodule=dm)
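
matplotlib and seaborn are imported above but never used; below is a minimal sketch of how the logs written by CSVLogger could be plotted once training finishes. It assumes the default Lightning layout (lightning_logs/version_N/metrics.csv under models_dir) and a validation_loss column, which depends on what LitSegger actually logs.

import pandas as pd
from matplotlib import pyplot as plt

# CSVLogger(models_dir) writes to <models_dir>/lightning_logs/version_<N>/metrics.csv
metrics = pd.read_csv(models_dir / "lightning_logs" / "version_0" / "metrics.csv")

# "validation_loss" is an assumed column name; adjust to the metric LitSegger logs
val = metrics.dropna(subset=["validation_loss"])
plt.plot(val["epoch"], val["validation_loss"])
plt.xlabel("epoch")
plt.ylabel("validation loss")
plt.savefig(models_dir / "validation_loss.png", dpi=150)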

scripts/2_predict_5k.py

Lines changed: 69 additions & 0 deletions
import os
from pathlib import Path

import dask.dataframe as dd
import pandas as pd
import scanpy as sc
import seaborn as sns
from matplotlib import pyplot as plt

from segger.training.segger_data_module import SeggerDataModule
from segger.prediction.predict_parquet import segment, load_model

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUPY_CACHE_DIR"] = "./.cupy"


# Raw data directory
XENIUM_DATA_DIR = Path(
    "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real"
)
transcripts_file = XENIUM_DATA_DIR / "transcripts.parquet"

# Preprocessed data directory
SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_seg_nuclei")

seg_tag = "human_CRC_seg_nuclei"
model_version = 0
models_dir = Path("./models") / seg_tag  # trained model directory

# Output directory
output_dir = Path(
    "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/human_CRC_seg_nuclei"
)

# Initialize the Lightning data module
dm = SeggerDataModule(
    data_dir=SEGGER_DATA_DIR,
    batch_size=1,
    num_workers=1,
)
dm.setup()

# Load the latest checkpoint
model_path = models_dir / "lightning_logs" / f"version_{model_version}"
model = load_model(model_path / "checkpoints")

receptive_field = {"k_bd": 4, "dist_bd": 15, "k_tx": 5, "dist_tx": 3}

segment(
    model,
    dm,
    save_dir=output_dir,
    seg_tag=seg_tag,
    transcript_file=transcripts_file,
    receptive_field=receptive_field,
    min_transcripts=5,
    score_cut=0.5,
    cell_id_col="segger_cell_id",
    save_transcripts=True,
    save_anndata=True,
    save_cell_masks=False,  # placeholder for future implementation
    use_cc=False,  # set True to keep fragments (groups of similar transcripts not attached to any nucleus)
    knn_method="kd_tree",
    verbose=True,
    gpu_ids=["0"],
    # client=client
)
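
With save_transcripts=True and save_anndata=True, segment() writes its outputs under output_dir; the exact filenames are not shown in this commit, so the snippet below simply globs for an .h5ad file (an assumption) and runs a quick sanity check on the resulting cells.

import scanpy as sc

# Filename is an assumption: pick up whichever .h5ad file segment() wrote under output_dir
h5ad_files = sorted(output_dir.rglob("*.h5ad"))
adata = sc.read_h5ad(h5ad_files[0])

print(adata)             # cells x genes matrix produced by segger
print(adata.obs.head())  # per-cell metadata

# Basic QC: total counts per segmented cell
sc.pp.calculate_qc_metrics(adata, inplace=True)
print(adata.obs["total_counts"].describe())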

scripts/create_data_merscope.py

Lines changed: 83 additions & 0 deletions
from pathlib import Path

import math
import numpy as np
import pandas as pd
import scanpy as sc

from segger.data.parquet.sample import STSampleParquet, STInMemoryDataset
from segger.data.utils import calculate_gene_celltype_abundance_embedding
from segger.data.parquet._utils import get_polygons_from_xy

"""
This script preprocesses MERSCOPE spatial transcriptomics data for the SEGGER cell
segmentation model.

Key steps:
1. Data Loading:
   - Loads scRNA-seq reference data to create gene-celltype embeddings
   - Imports MERSCOPE transcripts and nucleus boundaries

2. Parameter Optimization:
   - Calculates neighborhood parameters based on tissue characteristics
   - dist_tx: sets the transcript neighbor search radius to ~1/4 of the typical nucleus size
   - k_tx: determines the number of transcript neighbors to sample based on local density

3. Dataset Creation:
   - Filters transcripts to those overlapping nuclei
   - Creates graph connections between nearby transcripts
   - Splits data into training/validation sets
   - Saves in PyG format for SEGGER training

Usage:
- Input: raw MERSCOPE data (transcripts.parquet, nucleus_boundaries.parquet)
- Output: processed dataset with graph structure and embeddings
"""

# Define data paths
# MERSCOPE_DATA_DIR = Path('/omics/odcf/analysis/OE0606_projects_temp/MERSCOPE_projects/20241209_MERSCOPE5k_CNSL_BrM/20241209_MERSCOPE5k_CNSL_BrM/output-XETG00078__0041719__Region_1__20241203__142052')
# SEGGER_DATA_DIR = Path('data_tidy/pyg_datasets/CNSL_5k')
# SCRNASEQ_FILE = Path('/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad')
# CELLTYPE_COLUMN = 'celltype_minor'

MERSCOPE_DATA_DIR = Path('data_raw/merscope/processed/')
SEGGER_DATA_DIR = Path('data_tidy/pyg_datasets/merscope_liver')
# SCRNASEQ_FILE = Path('/omics/groups/OE0606/internal/mimmo/MERSCOPE/notebooks/data/scData/bh/bh_mng_scdata_20250306.h5ad')
# CELLTYPE_COLUMN = 'annot_v1'

# Calculate gene-celltype embeddings from reference data
# gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(
#     sc.read(SCRNASEQ_FILE),
#     CELLTYPE_COLUMN,
# )

# Initialize the spatial transcriptomics sample object
sample = STSampleParquet(
    base_dir=MERSCOPE_DATA_DIR,
    n_workers=4,
    sample_type="merscope",
    buffer_ratio=1,
    # weights=gene_celltype_abundance_embedding
)

# Save the processed dataset for SEGGER
# Parameters:
# - k_bd/dist_bd: control nucleus boundary point connections
# - k_tx/dist_tx: control transcript neighborhood connections
# - tile_width/height: size of spatial tiles for processing
# - neg_sampling_ratio: ratio of negative to positive samples
# - val_prob: fraction of data held out for validation
sample.save_debug(
    data_dir=SEGGER_DATA_DIR,
    k_bd=3,                  # number of boundary points to connect
    dist_bd=15,              # maximum distance for boundary connections
    k_tx=5,                  # transcript neighbors per node
    dist_tx=20,              # transcript neighbor search radius
    tile_width=500,          # tile size for processing
    tile_height=500,
    neg_sampling_ratio=5.0,  # 5:1 negative:positive samples
    frac=1.0,                # use all data
    val_prob=0.3,            # 30% validation set
    test_prob=0,             # no test set
)

scripts/train_MNG_5k.sh

Lines changed: 13 additions & 0 deletions
#!/bin/bash

DATA_ROOT="data_tidy/pyg_datasets/MNG_5k_sampled"

# Submit one training job per preprocessed dataset folder
for folder in "$DATA_ROOT"/*; do
    if [ -d "$folder" ]; then
        echo "Submitting job for $folder"
        bsub -o train_yiheng_5k \
            -gpu num=4:j_exclusive=yes:gmem=20.7G \
            -R "rusage[mem=100GB]" \
            -q gpu-debian \
            python /dkfz/cluster/gpu/data/OE0606/elihei/segger_dev/scripts/train_model.py --data_dir "$folder"
    fi
done
