Skip to content

Commit b629684

Browse files
authored
Merge pull request #99 from EliHei2/cosmx_run
CosMx Data Support and Transcript Processing Improvements
2 parents 1a46ea4 + 7a5285e commit b629684

File tree

14 files changed

+774
-232
lines changed

14 files changed

+774
-232
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# How segger Works
1010

1111
![Segger Model](docs/images/Segger_model_08_2024.png)
12-
Some illustrations (cells and data) are borrowed from [Biorender](https://www.biorender.com/) and [BIDCell's paper](https://www.nature.com/articles/s41467-023-44560-w).
12+
Some illustrations (cells and data) are borrowed from [Biorender](https://www.biorender.com/) and [BIDCell's paper](https://www.nature.com/articles/s41467-023-44560-w).
1313

1414
---
1515

scripts/create_data_cosmx.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
from segger.data.parquet.sample import STSampleParquet, STInMemoryDataset
from pathlib import Path  # stdlib pathlib instead of the third-party `path` package
from segger.data.utils import calculate_gene_celltype_abundance_embedding
import scanpy as sc
import pandas as pd
import math
import numpy as np
from segger.data.parquet._utils import get_polygons_from_xy

"""
This script preprocesses CosMx spatial transcriptomics data for the SEGGER
cell segmentation model.

Key steps:
1. Data Loading:
   - (Optional, currently disabled) loads scRNA-seq reference data to create
     gene-celltype embeddings
   - Imports CosMx transcripts and nucleus boundaries

2. Parameter Optimization:
   - Calculates candidate neighborhood parameters based on tissue characteristics
   - dist_tx: transcript neighbor search radius = 1/4 of typical nucleus diameter
   - k_tx: number of transcript neighbors, based on local transcript density
   - NOTE(review): the computed values are only printed; `sample.save` below
     currently uses hardcoded overrides (k_tx=20, dist_tx=70)

3. Dataset Creation:
   - Creates graph connections between nearby transcripts
   - Splits data into training/validation sets
   - Saves in PyG format for SEGGER training

Usage:
- Input: processed CosMx data (transcripts.parquet, nucleus_boundaries.parquet)
- Output: processed dataset with graph structure and embeddings
"""

# Define data paths
# NOTE(review): variable name kept for continuity with the Xenium scripts,
# but this run points at CosMx data (see sample_type="cosmx" below).
XENIUM_DATA_DIR = Path("data_raw/cosmx/human_pancreas/processed/")
SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/cosmx_pancreas_50")
# SCRNASEQ_FILE = Path('/omics/groups/OE0606/internal/mimmo/Xenium/notebooks/data/scData/bh/bh_mng_scdata_20250306.h5ad')
# CELLTYPE_COLUMN = 'annot_v1'

# Calculate gene-celltype embeddings from reference data (disabled for this run)
# gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(
#     sc.read(SCRNASEQ_FILE),
#     CELLTYPE_COLUMN
# )

# Initialize spatial transcriptomics sample object
sample = STSampleParquet(
    base_dir=XENIUM_DATA_DIR,
    n_workers=4,
    sample_type="cosmx",
    buffer_ratio=1,
    # weights=gene_celltype_abundance_embedding
)

# Load raw tables. Unlike the Xenium scripts there is no overlaps_nucleus
# filter here; CosMx transcripts carry their assignment in the "cell" column.
transcripts = pd.read_parquet(XENIUM_DATA_DIR / "transcripts.parquet")
boundaries = pd.read_parquet(XENIUM_DATA_DIR / "nucleus_boundaries.parquet")

# Calculate candidate neighborhood parameters from the data
transcript_counts = transcripts.groupby("cell").size()
nucleus_polygons = get_polygons_from_xy(boundaries, "x_global_px", "y_global_px", "cell")

# Transcript density per nucleus (transcripts per unit area).
# NOTE(review): the previous version computed area / count — the inverse of a
# density — which made the expected-count formula below dimensionally wrong.
transcript_densities = transcript_counts / nucleus_polygons[transcript_counts.index].area
nucleus_diameter = nucleus_polygons.minimum_bounding_radius().median() * 2

# Set neighborhood parameters
dist_tx = nucleus_diameter / 4  # Search radius = 1/4 typical nucleus diameter
k_tx = math.ceil(
    np.quantile(dist_tx**2 * np.pi * transcript_densities, 0.9)
)  # 90th percentile of expected transcript count in a dist_tx-radius circle

print(f"Calculated parameters: k_tx={k_tx}, dist_tx={dist_tx:.2f}")

# Save processed dataset for SEGGER
# Parameters:
# - k_bd/dist_bd: Control nucleus boundary point connections
# - k_tx/dist_tx: Control transcript neighborhood connections
# - tile_width/height: Size of spatial tiles for processing
# - neg_sampling_ratio: Ratio of negative to positive samples
# - val_prob: Fraction of data for validation
sample.save(
    data_dir=SEGGER_DATA_DIR,
    k_bd=3,  # Number of boundary points to connect
    dist_bd=15,  # Maximum distance for boundary connections
    k_tx=20,  # NOTE(review): hardcoded override — computed k_tx above is ignored
    dist_tx=70,  # NOTE(review): hardcoded override — computed dist_tx is ignored
    tile_width=500,  # Tile size for processing
    tile_height=500,
    neg_sampling_ratio=5.0,  # 5:1 negative:positive samples
    frac=1.0,  # Use all data
    val_prob=0.3,  # 30% validation set
    test_prob=0,  # No test set
)

scripts/create_data_fast_sample.py

Lines changed: 81 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -7,97 +7,93 @@
77
import numpy as np
88
from segger.data.parquet._utils import get_polygons_from_xy
99

10-
xenium_data_dir = Path('data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/')
11-
segger_data_dir = Path('data_tidy/pyg_datasets/bc_rep1_emb_200_final')
12-
13-
14-
scrnaseq_file = Path('/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad')
15-
celltype_column = 'celltype_minor'
16-
gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(
17-
sc.read(scrnaseq_file),
18-
celltype_column
10+
"""
11+
This script preprocesses Xenium spatial transcriptomics data for SEGGER cell segmentation model.
12+
13+
Key steps:
14+
1. Data Loading:
15+
- Loads scRNA-seq reference data to create gene-celltype embeddings
16+
- Imports Xenium transcripts and nucleus boundaries
17+
18+
2. Parameter Optimization:
19+
- Calculates optimal neighborhood parameters based on tissue characteristics
20+
- dist_tx: Sets transcript neighbor search radius to 1/4 of typical nucleus size
21+
- k_tx: Determines number of transcripts to sample based on local density
22+
23+
3. Dataset Creation:
24+
- Filters transcripts to those overlapping nuclei
25+
- Creates graph connections between nearby transcripts
26+
- Splits data into training/validation sets
27+
- Saves in PyG format for SEGGER training
28+
29+
Usage:
30+
- Input: Raw Xenium data (transcripts.parquet, nucleus_boundaries.parquet)
31+
- Output: Processed dataset with graph structure and embeddings
32+
"""
33+
34+
# Define data paths
35+
# XENIUM_DATA_DIR = Path('/omics/odcf/analysis/OE0606_projects_temp/xenium_projects/20241209_Xenium5k_CNSL_BrM/20241209_Xenium5k_CNSL_BrM/output-XETG00078__0041719__Region_1__20241203__142052')
36+
# SEGGER_DATA_DIR = Path('data_tidy/pyg_datasets/CNSL_5k')
37+
# # SCRNASEQ_FILE = Path('/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad')
38+
# CELLTYPE_COLUMN = 'celltype_minor'
39+
40+
41+
XENIUM_DATA_DIR = Path(
42+
"/omics/odcf/analysis/OE0606_projects_temp/oncolgy_data_exchange/analysis_domenico/project_24/output-XETG00423__0053177__mng_04_TMA__20250306__170821"
1943
)
44+
SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/MNG_0053177")
45+
SCRNASEQ_FILE = Path("/omics/groups/OE0606/internal/mimmo/Xenium/notebooks/data/scData/bh/bh_mng_scdata_20250306.h5ad")
46+
CELLTYPE_COLUMN = "annot_v1"
2047

48+
# Calculate gene-celltype embeddings from reference data
49+
# gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(
50+
# sc.read(SCRNASEQ_FILE),
51+
# CELLTYPE_COLUMN
52+
# )
53+
54+
# Initialize spatial transcriptomics sample object
2155
sample = STSampleParquet(
22-
base_dir=xenium_data_dir,
56+
base_dir=XENIUM_DATA_DIR,
2357
n_workers=4,
2458
sample_type="xenium",
25-
weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available
26-
)
27-
28-
transcripts = pd.read_parquet(xenium_data_dir / "transcripts.parquet", filters=[[("overlaps_nucleus", "=", 1)]])
29-
boundaries = pd.read_parquet(xenium_data_dir / "nucleus_boundaries.parquet")
30-
31-
sizes = transcripts.groupby("cell_id").size()
32-
polygons = get_polygons_from_xy(boundaries, "vertex_x", "vertex_y", "cell_id")
33-
densities = polygons[sizes.index].area / sizes
34-
bd_width = polygons.minimum_bounding_radius().median() * 2
35-
36-
# 1/4 median boundary diameter
37-
dist_tx = bd_width / 4
38-
# 90th percentile density of bounding circle with radius=dist_tx
39-
k_tx = math.ceil(np.quantile(dist_tx**2 * np.pi * densities, 0.9))
40-
41-
print(k_tx)
42-
print(dist_tx)
43-
44-
45-
sample.save(
46-
data_dir=segger_data_dir,
47-
k_bd=3,
48-
dist_bd=15,
49-
k_tx=3,
50-
dist_tx=5,
51-
tile_width=200,
52-
tile_height=200,
53-
neg_sampling_ratio=5.0,
54-
frac=1.0,
55-
val_prob=0.3,
56-
test_prob=0,
57-
)
58-
59-
60-
xenium_data_dir = Path('data_tidy/bc_5k')
61-
segger_data_dir = Path('data_tidy/pyg_datasets/bc_5k_emb_new')
62-
63-
64-
65-
sample = STSampleParquet(
66-
base_dir=xenium_data_dir,
67-
n_workers=8,
68-
sample_type='xenium',
69-
weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available
59+
# weights=gene_celltype_abundance_embedding
7060
)
7161

72-
73-
transcripts = pd.read_parquet(xenium_data_dir / "transcripts.parquet", filters=[[("overlaps_nucleus", "=", 1)]])
74-
boundaries = pd.read_parquet(xenium_data_dir / "nucleus_boundaries.parquet")
75-
76-
sizes = transcripts.groupby("cell_id").size()
77-
polygons = get_polygons_from_xy(boundaries, "vertex_x", "vertex_y", "cell_id")
78-
densities = polygons[sizes.index].area / sizes
79-
bd_width = polygons.minimum_bounding_radius().median() * 2
80-
81-
# 1/4 median boundary diameter
82-
dist_tx = bd_width / 4
83-
# 90th percentile density of bounding circle with radius=dist_tx
84-
k_tx = math.ceil(np.quantile(dist_tx**2 * np.pi * densities, 0.9))
85-
86-
print(k_tx)
87-
print(dist_tx)
88-
89-
62+
# Load and filter data
63+
transcripts = pd.read_parquet(XENIUM_DATA_DIR / "transcripts.parquet", filters=[[("overlaps_nucleus", "=", 1)]])
64+
boundaries = pd.read_parquet(XENIUM_DATA_DIR / "nucleus_boundaries.parquet")
65+
66+
# Calculate optimal neighborhood parameters
67+
transcript_counts = transcripts.groupby("cell_id").size()
68+
nucleus_polygons = get_polygons_from_xy(boundaries, "vertex_x", "vertex_y", "cell_id")
69+
transcript_densities = nucleus_polygons[transcript_counts.index].area / transcript_counts
70+
nucleus_diameter = nucleus_polygons.minimum_bounding_radius().median() * 2
71+
72+
# Set neighborhood parameters
73+
dist_tx = nucleus_diameter / 4 # Search radius = 1/4 nucleus diameter
74+
k_tx = math.ceil(
75+
np.quantile(dist_tx**2 * np.pi * transcript_densities, 0.9)
76+
) # Sample size based on 90th percentile density
77+
78+
print(f"Calculated parameters: k_tx={k_tx}, dist_tx={dist_tx:.2f}")
79+
80+
# Save processed dataset for SEGGER
81+
# Parameters:
82+
# - k_bd/dist_bd: Control nucleus boundary point connections
83+
# - k_tx/dist_tx: Control transcript neighborhood connections
84+
# - tile_width/height: Size of spatial tiles for processing
85+
# - neg_sampling_ratio: Ratio of negative to positive samples
86+
# - val_prob: Fraction of data for validation
9087
sample.save(
91-
data_dir=segger_data_dir,
92-
k_bd=3,
93-
dist_bd=15.0,
94-
k_tx=15,
95-
dist_tx=3,
96-
tile_size=50_000,
97-
neg_sampling_ratio=5.0,
98-
frac=0.1,
99-
val_prob=0.1,
100-
test_prob=0.1,
88+
data_dir=SEGGER_DATA_DIR,
89+
k_bd=3, # Number of boundary points to connect
90+
dist_bd=15, # Maximum distance for boundary connections
91+
k_tx=k_tx, # Use calculated optimal transcript neighbors
92+
dist_tx=dist_tx, # Use calculated optimal search radius
93+
tile_width=100, # Tile size for processing
94+
tile_height=100,
95+
neg_sampling_ratio=5.0, # 5:1 negative:positive samples
96+
frac=1.0, # Use all data
97+
val_prob=0.3, # 30% validation set
98+
test_prob=0, # No test set
10199
)
102-
103-

scripts/predict_model_sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
seg_tag = "bc_fast_data_emb_major"
2323
model_version = 1
2424

25-
segger_data_dir = Path('data_tidy/pyg_datasets') / seg_tag
26-
models_dir = Path("./models") / seg_tag
25+
segger_data_dir = Path("data_tidy/pyg_datasets") / seg_tag
26+
models_dir = Path("./models") / seg_tag
2727
benchmarks_dir = Path("/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc")
2828
transcripts_file = "data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet"
2929
# Initialize the Lightning data module

scripts/train_model_sample.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@
1515
import os
1616

1717

18-
segger_data_dir = segger_data_dir = Path('data_tidy/pyg_datasets/bc_rep1_emb_final_200')
19-
models_dir = Path("./models/bc_rep1_emb_final_200")
18+
segger_data_dir = segger_data_dir = Path("data_tidy/pyg_datasets/cosmx_pancreas")
19+
models_dir = Path("./models/cosmx_pancreas")
2020

2121
# Base directory to store Pytorch Lightning models
2222
# models_dir = Path('models')
2323

2424
# Initialize the Lightning data module
2525
dm = SeggerDataModule(
2626
data_dir=segger_data_dir,
27-
batch_size=1,
27+
batch_size=2,
2828
num_workers=2,
2929
)
3030

@@ -35,37 +35,34 @@
3535

3636
# If you use custom gene embeddings, use the following two lines instead:
3737
is_token_based = False
38-
num_tx_tokens = dm.train[0].x_dict["tx"].shape[1] # Set the number of tokens to the number of genes
38+
num_tx_tokens = len(dm.train[0].x_dict["tx"]) # Set the number of tokens to the number of genes
3939

4040

4141
num_bd_features = dm.train[0].x_dict["bd"].shape[1]
4242

4343
# Initialize the Lightning model
4444
ls = LitSegger(
45-
is_token_based = is_token_based,
46-
num_node_features = {"tx": num_tx_tokens, "bd": num_bd_features},
47-
init_emb=8,
45+
is_token_based=is_token_based,
46+
num_node_features={"tx": num_tx_tokens, "bd": num_bd_features},
47+
init_emb=8,
4848
hidden_channels=64,
4949
out_channels=16,
5050
heads=4,
5151
num_mid_layers=3,
52-
aggr='sum',
53-
learning_rate=1e-3
52+
aggr="sum",
53+
learning_rate=1e-3,
5454
)
5555

5656
# Initialize the Lightning trainer
5757
trainer = Trainer(
58-
accelerator='cuda',
59-
strategy='auto',
60-
precision='16-mixed',
61-
devices=2, # set higher number if more gpus are available
62-
max_epochs=400,
58+
accelerator="cpu",
59+
strategy="auto",
60+
precision="16-mixed",
61+
devices=4, # set higher number if more gpus are available
62+
max_epochs=100,
6363
default_root_dir=models_dir,
6464
logger=CSVLogger(models_dir),
6565
)
6666

6767

68-
trainer.fit(
69-
model=ls,
70-
datamodule=dm
71-
)
68+
trainer.fit(model=ls, datamodule=dm)

0 commit comments

Comments
 (0)