Skip to content

Commit 37253a1

Browse files
committed
fixes filtering the substrings as reported in #89
1 parent 56e7064 commit 37253a1

File tree

8 files changed

+77
-33
lines changed

8 files changed

+77
-33
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,7 @@ dev*
176176

177177
# Custom
178178
*_old*
179-
.dev
179+
.dev
180+
181+
scripts/*
182+
.scripts/*

docs/notebooks/segger_tutorial.ipynb

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,29 @@
358358
"- **`--precision`**: Enables mixed precision training (e.g., `16-mixed`), which can speed up training and reduce memory usage while maintaining accuracy."
359359
]
360360
},
361+
{
362+
"cell_type": "code",
363+
"execution_count": null,
364+
"id": "cfff5dca",
365+
"metadata": {},
366+
"outputs": [],
367+
"source": [
368+
"# Evaluate results\n",
369+
"model_version = 0 # 'v_num' from training output above\n",
370+
"model_path = Path('../human_CRC') / 'lightning_logs' / f'version_{model_version}'\n",
371+
"metrics = pd.read_csv(model_path / 'metrics.csv', index_col=1)\n",
372+
"\n",
373+
"fig, ax = plt.subplots(1,1, figsize=(2,2))\n",
374+
"\n",
375+
"for col in metrics.columns.difference(['epoch']):\n",
376+
" metric = metrics[col].dropna()\n",
377+
" ax.plot(metric.index, metric.values, label=col)\n",
378+
"\n",
379+
"ax.legend(loc=(1, 0.33))\n",
380+
"ax.set_ylim(0, 1)\n",
381+
"ax.set_xlabel('Step')"
382+
]
383+
},
361384
{
362385
"cell_type": "markdown",
363386
"id": "9a7d20c6-ca16-4beb-b627-afb41e3fb491",
@@ -461,6 +484,14 @@
461484
"Below is an example of how to run the faster Segger prediction pipeline using the command line:"
462485
]
463486
},
487+
{
488+
"cell_type": "code",
489+
"execution_count": null,
490+
"id": "cdda303d",
491+
"metadata": {},
492+
"outputs": [],
493+
"source": []
494+
},
464495
{
465496
"cell_type": "code",
466497
"execution_count": null,

scripts/create_data_cosmx.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040

4141
XENIUM_DATA_DIR = Path("data_raw/cosmx/human_pancreas/processed/")
42-
SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/cosmx_pancreas_50")
42+
SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/cosmx_pancreas_fixed_")
4343
# SCRNASEQ_FILE = Path('/omics/groups/OE0606/internal/mimmo/Xenium/notebooks/data/scData/bh/bh_mng_scdata_20250306.h5ad')
4444
# CELLTYPE_COLUMN = 'annot_v1'
4545

@@ -92,11 +92,11 @@
9292
data_dir=SEGGER_DATA_DIR,
9393
k_bd=3, # Number of boundary points to connect
9494
dist_bd=15, # Maximum distance for boundary connections
95-
k_tx=20, # Use calculated optimal transcript neighbors
95+
k_tx=5, # Number of transcript neighbors (fixed value)
9696
dist_tx=70, # Use calculated optimal search radius
97-
tile_width=500, # Tile size for processing
98-
tile_height=500,
99-
neg_sampling_ratio=5.0, # 5:1 negative:positive samples
97+
tile_width=1000, # Tile size for processing
98+
tile_height=1000, # Tile size for processing
99+
neg_sampling_ratio=10.0, # 10:1 negative:positive samples
100100
frac=1.0, # Use all data
101101
val_prob=0.3, # 30% validation set
102102
test_prob=0, # No test set

scripts/create_data_fast_sample.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@
3939

4040

4141
XENIUM_DATA_DIR = Path(
42-
"/omics/odcf/analysis/OE0606_projects_temp/oncolgy_data_exchange/analysis_domenico/project_24/output-XETG00423__0053177__mng_04_TMA__20250306__170821"
42+
"/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC"
4343
)
44-
SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/MNG_0053177")
45-
SCRNASEQ_FILE = Path(
46-
"/omics/groups/OE0606/internal/mimmo/Xenium/notebooks/data/scData/bh/bh_mng_scdata_20250306.h5ad"
47-
)
48-
CELLTYPE_COLUMN = "annot_v1"
44+
SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_full")
45+
# SCRNASEQ_FILE = Path(
46+
# "/omics/groups/OE0606/internal/mimmo/Xenium/notebooks/data/scData/bh/bh_mng_scdata_20250306.h5ad"
47+
# )
48+
# CELLTYPE_COLUMN = "annot_v1"
4949

5050
# Calculate gene-celltype embeddings from reference data
5151
# gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(
@@ -94,10 +94,10 @@
9494
data_dir=SEGGER_DATA_DIR,
9595
k_bd=3, # Number of boundary points to connect
9696
dist_bd=15, # Maximum distance for boundary connections
97-
k_tx=k_tx, # Use calculated optimal transcript neighbors
98-
dist_tx=dist_tx, # Use calculated optimal search radius
99-
tile_width=100, # Tile size for processing
100-
tile_height=100,
97+
k_tx=5, # Number of transcript neighbors (fixed value)
98+
dist_tx=5, # Transcript search radius (fixed value)
99+
tile_width=200, # Tile size for processing
100+
tile_height=200,
101101
neg_sampling_ratio=5.0, # 5:1 negative:positive samples
102102
frac=1.0, # Use all data
103103
val_prob=0.3, # 30% validation set

scripts/predict_model_sample.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,33 @@
1010
from pathlib import Path
1111

1212
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
13+
os.environ["CUPY_CACHE_DIR"] = "./.cupy"
1314
import cupy as cp
1415
from dask.distributed import Client, LocalCluster
1516
from dask_cuda import LocalCUDACluster
1617
import dask.dataframe as dd
1718

1819

19-
seg_tag = "bc_rep1_emb_final"
20-
model_version = 6
20+
seg_tag = "human_CRC"
21+
model_version = 0
2122

22-
seg_tag = "bc_fast_data_emb_major"
23-
model_version = 1
23+
seg_tag = "human_CRC"
24+
model_version = 0
25+
26+
27+
28+
XENIUM_DATA_DIR = Path(
29+
"/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC"
30+
)
31+
SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_full")
2432

2533
segger_data_dir = Path("data_tidy/pyg_datasets") / seg_tag
2634
models_dir = Path("./models") / seg_tag
2735
benchmarks_dir = Path(
28-
"/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc"
36+
"/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/human_CRC"
2937
)
3038
transcripts_file = (
31-
"data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet"
39+
"/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC/transcripts.parquet"
3240
)
3341
# Initialize the Lightning data module
3442
dm = SeggerDataModule(

scripts/train_model_sample.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from segger.data.io import XeniumSample
2-
from segger.training.train import LitSegger
31
from segger.training.segger_data_module import SeggerDataModule
42
# from segger.prediction.predict import predict, load_model
53
from segger.models.segger_model import Segger
@@ -19,8 +17,8 @@
1917

2018

2119

22-
segger_data_dir = segger_data_dir = Path("data_tidy/pyg_datasets/cosmx_pancreas")
23-
models_dir = Path("./models/cosmx_pancreas")
20+
segger_data_dir = segger_data_dir = Path("data_tidy/pyg_datasets/human_CRC_full")
21+
models_dir = Path("./models/human_CRC")
2422

2523
# Base directory to store Pytorch Lightning models
2624
# models_dir = Path('models')
@@ -46,7 +44,7 @@
4644

4745
model = Segger(
4846
# is_token_based=is_token_based,
49-
num_tx_tokens= 25000,
47+
num_tx_tokens= 850,
5048
init_emb=8,
5149
hidden_channels=64,
5250
out_channels=16,
@@ -73,14 +71,14 @@
7371

7472
# Initialize the Lightning trainer
7573
trainer = Trainer(
76-
accelerator="cpu",
74+
accelerator="gpu",
7775
strategy="auto",
7876
precision="16-mixed",
79-
devices=2, # set higher number if more gpus are available
80-
max_epochs=400,
77+
devices=4, # set higher number if more gpus are available
78+
max_epochs=100,
8179
default_root_dir=models_dir,
8280
logger=CSVLogger(models_dir),
8381
)
8482

8583

86-
trainer.fit(ls , datamodule=dm)
84+
trainer.fit(ls, datamodule=dm)

src/segger/data/parquet/_settings/xenium.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ transcripts:
44
y: "y_location"
55
z: "z_location"
66
id: "transcript_id"
7-
label: "target"
7+
label: "feature_name"
88
nuclear_column: "overlaps_nucleus"
99
nuclear_value: 1
1010
qv_column: "qv"

src/segger/data/parquet/sample.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
from itertools import compress
1515
from torch_geometric.data import HeteroData
1616
from torch_geometric.transforms import RandomLinkSplit
17+
from pqdm.threads import pqdm
1718
import torch
1819
import random
1920
from segger.data.parquet.transcript_embedding import TranscriptEmbedding
21+
# import re
2022

2123

2224
# TODO: Add documentation for settings
@@ -203,7 +205,8 @@ def transcripts_metadata(self) -> dict:
203205
missing_genes = list(set(names_str) - set(self._emb_genes))
204206
logging.warning(f"Number of missing genes: {len(missing_genes)}")
205207
self.settings.transcripts.filter_substrings.extend(missing_genes)
206-
pattern = "|".join(self.settings.transcripts.filter_substrings)
208+
# pattern = "|".join(self.settings.transcripts.filter_substrings)
209+
pattern = "|".join(f"^{s}" for s in self.settings.transcripts.filter_substrings)
207210
mask = pc.invert(pc.match_substring_regex(names, pattern))
208211
filtered_names = pc.filter(names, mask).to_pylist()
209212
metadata["feature_names"] = [
@@ -674,6 +677,7 @@ def _load_transcripts(self, path: os.PathLike, min_qv: float = 30.0):
674677
transcripts[self.settings.transcripts.label] = transcripts[
675678
self.settings.transcripts.label
676679
].apply(lambda x: x.decode("utf-8") if isinstance(x, bytes) else x)
680+
qv_column = getattr(self.settings.transcripts, "qv_column", None)
677681
transcripts = utils.filter_transcripts(
678682
transcripts,
679683
self.settings.transcripts.label,

0 commit comments

Comments
 (0)