diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4620eaa9..a3b41d0b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,8 +8,8 @@ minimum_pre_commit_version: 2.16.0
 ci:
   skip: []
 repos:
-  - repo: https://github.com/psf/black
-    rev: 24.10.0
+  - repo: https://github.com/psf/black-pre-commit-mirror
+    rev: 25.9.0
     hooks:
       - id: black
   - repo: https://github.com/pre-commit/mirrors-prettier
@@ -17,6 +17,6 @@ repos:
     hooks:
       - id: prettier
   - repo: https://github.com/asottile/blacken-docs
-    rev: 1.19.1
+    rev: 1.20.0
     hooks:
       - id: blacken-docs
diff --git a/analysis_summary.html b/analysis_summary.html
index 78629d01..06cd4084 100755
--- a/analysis_summary.html
+++ b/analysis_summary.html
@@ -1,88 +1,5820 @@
[auto-generated analysis_summary.html report regenerated: ~88 lines of markup replaced by ~5,820 lines; generated HTML content omitted]
diff --git a/scripts/0_data_creation_5k_nucleus.py b/scripts/0_data_creation_5k_nucleus.py index 3751201b..30ac585a 100644 --- a/scripts/0_data_creation_5k_nucleus.py +++ b/scripts/0_data_creation_5k_nucleus.py @@ -35,19 +35,15 @@ """ - XENIUM_DATA_DIR = Path( "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real" ) SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_seg_nuclei") -SCRNASEQ_FILE = Path( - "data_tidy/Human_CRC/scRNAseq.h5ad" -) -CELLTYPE_COLUMN = "Level1" # change this to your column name +SCRNASEQ_FILE = Path("data_tidy/Human_CRC/scRNAseq.h5ad") +CELLTYPE_COLUMN = "Level1"  # change this to your column name scrnaseq = sc.read(SCRNASEQ_FILE) - # subsample the scRNAseq if needed # sc.pp.subsample(scrnaseq, 0.1) # scrnaseq.var_names_make_unique() @@ -55,8 +51,7 @@ # Calculate gene-celltype embeddings from reference data gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding( - scrnaseq, - CELLTYPE_COLUMN + scrnaseq, CELLTYPE_COLUMN ) # Initialize spatial transcriptomics sample object @@ -65,7 +60,7 @@ n_workers=4, sample_type="xenium", weights=gene_celltype_abundance_embedding, - scale_factor=1. + scale_factor=1.0, ) @@ -77,7 +72,7 @@ dist_tx=5, # Use calculated optimal search radius tile_size=10000, # Tile size for processing # tile_height=50, - neg_sampling_ratio=10., # 5:1 negative:positive samples + neg_sampling_ratio=10.0, # 5:1 negative:positive samples frac=1.0, # Use all data val_prob=0.3, # 30% validation set test_prob=0, # No test set diff --git a/scripts/1_train_5k.py b/scripts/1_train_5k.py index f621a3be..c88ed822 100644 --- a/scripts/1_train_5k.py +++ b/scripts/1_train_5k.py @@ -1,4 +1,5 @@ from segger.training.segger_data_module import SeggerDataModule + # from segger.prediction.predict import predict, load_model from segger.models.segger_model import Segger from segger.training.train import LitSegger @@ -9,14 +10,15 @@ from lightning.pytorch.plugins.environments import LightningEnvironment from matplotlib import pyplot as plt import seaborn as sns + # import pandas as pd from segger.data.utils import calculate_gene_celltype_abundance_embedding + # import scanpy as sc import os from lightning import LightningModule - segger_data_dir = Path("data_tidy/pyg_datasets/human_CRC_seg_cells") models_dir = Path("./models/human_CRC_seg_cells") @@ -43,14 +45,18 @@ model = Segger( - num_tx_tokens= num_tx_tokens, + num_tx_tokens=num_tx_tokens, init_emb=8, hidden_channels=32, out_channels=16, heads=4, num_mid_layers=3, ) -model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum") +model = to_hetero( + model, + (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), + aggr="sum", +) batch = dm.train[0] model.forward(batch.x_dict, batch.edge_index_dict) diff --git a/scripts/2_predict_5k.py b/scripts/2_predict_5k.py index f93f311c..5967b6d5 100644 --- a/scripts/2_predict_5k.py +++ b/scripts/2_predict_5k.py @@ -8,26 +8,27 @@ import dask.dataframe as dd import pandas as pd from pathlib import Path + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" os.environ["CUPY_CACHE_DIR"] = "./.cupy" -XENIUM_DATA_DIR = Path( #raw data dir "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real" ) -transcripts_file = ( - XENIUM_DATA_DIR / "transcripts.parquet" -) +transcripts_file = XENIUM_DATA_DIR / "transcripts.parquet" 
-SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_seg_nuclei") # preprocessed data dir +SEGGER_DATA_DIR = Path( + "data_tidy/pyg_datasets/human_CRC_seg_nuclei" +) # preprocessed data dir seg_tag = "human_CRC_seg_nuclei" model_version = 0 -models_dir = Path("./models") / seg_tag #trained model dir +models_dir = Path("./models") / seg_tag # trained model dir -output_dir = Path( #output dir +output_dir = Path( # output dir "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/human_CRC_seg_nuclei" ) @@ -58,10 +59,10 @@ min_transcripts=5, score_cut=0.5, cell_id_col="segger_cell_id", - save_transcripts= True, - save_anndata= True, - save_cell_masks= False, # Placeholder for future implementation - use_cc=False, # if one wants fragments (groups of similar transcripts not attached to any nuclei) + save_transcripts=True, + save_anndata=True, + save_cell_masks=False, # Placeholder for future implementation + use_cc=False, # if one wants fragments (groups of similar transcripts not attached to any nuclei) knn_method="kd_tree", verbose=True, gpu_ids=["0"], diff --git a/scripts/create_data_cosmx.py b/scripts/create_data_cosmx.py index 2809d063..d6e78877 100644 --- a/scripts/create_data_cosmx.py +++ b/scripts/create_data_cosmx.py @@ -69,7 +69,7 @@ ) -cells = list(set(transcript_counts.index) & set(nucleus_polygons.index)) +cells = list(set(transcript_counts.index) & set(nucleus_polygons.index)) nucleus_polygons = nucleus_polygons[cells] transcript_counts = transcript_counts[cells] diff --git a/scripts/create_data_fast_sample.py b/scripts/create_data_fast_sample.py index 17325313..b9f90f75 100644 --- a/scripts/create_data_fast_sample.py +++ b/scripts/create_data_fast_sample.py @@ -42,17 +42,14 @@ "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real" ) SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/human_CRC_seg_nuclei") -SCRNASEQ_FILE = Path( - "data_tidy/Human_CRC/scRNAseq.h5ad" -) +SCRNASEQ_FILE = Path("data_tidy/Human_CRC/scRNAseq.h5ad") CELLTYPE_COLUMN = "Level1" scrnaseq = sc.read(SCRNASEQ_FILE) sc.pp.subsample(scrnaseq, 0.1) scrnaseq.var_names_make_unique() # Calculate gene-celltype embeddings from reference data gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding( - scrnaseq, - CELLTYPE_COLUMN + scrnaseq, CELLTYPE_COLUMN ) # Initialize spatial transcriptomics sample object @@ -61,7 +58,7 @@ n_workers=4, sample_type="xenium", # scale_factor=0.8, - weights=gene_celltype_abundance_embedding + weights=gene_celltype_abundance_embedding, ) # # Load and filter datas diff --git a/scripts/create_data_merscope.py b/scripts/create_data_merscope.py index 18de3ad9..79ca8617 100644 --- a/scripts/create_data_merscope.py +++ b/scripts/create_data_merscope.py @@ -38,8 +38,8 @@ # CELLTYPE_COLUMN = 'celltype_minor' -MERSCOPE_DATA_DIR = Path('data_raw/merscope/processed/') -SEGGER_DATA_DIR = Path('data_tidy/pyg_datasets/merscope_liver') +MERSCOPE_DATA_DIR = Path("data_raw/merscope/processed/") +SEGGER_DATA_DIR = Path("data_tidy/pyg_datasets/merscope_liver") # SCRNASEQ_FILE = Path('/omics/groups/OE0606/internal/mimmo/MERSCOPE/notebooks/data/scData/bh/bh_mng_scdata_20250306.h5ad') # CELLTYPE_COLUMN = 'annot_v1' @@ -80,4 +80,4 @@ frac=1.0, # Use all data val_prob=0.3, # 30% validation set test_prob=0, # No test set -) \ No newline at end of file +) diff --git a/scripts/predict_model_sample.py b/scripts/predict_model_sample.py index 249bb439..5d915ba9 100644 --- a/scripts/predict_model_sample.py +++ 
b/scripts/predict_model_sample.py @@ -17,12 +17,10 @@ import dask.dataframe as dd - seg_tag = "human_CRC_seg_cells" model_version = 0 - XENIUM_DATA_DIR = Path( "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real" ) @@ -32,9 +30,7 @@ benchmarks_dir = Path( "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/human_CRC_seg_cells" ) -transcripts_file = ( - "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real/transcripts.parquet" -) +transcripts_file = "/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_raw/xenium_seg_kit/human_CRC_real/transcripts.parquet" # Initialize the Lightning data module dm = SeggerDataModule( data_dir=SEGGER_DATA_DIR, diff --git a/scripts/train_cosmx.py b/scripts/train_cosmx.py index d58ac3fe..bb435bd6 100644 --- a/scripts/train_cosmx.py +++ b/scripts/train_cosmx.py @@ -1,4 +1,5 @@ from segger.training.segger_data_module import SeggerDataModule + # from segger.prediction.predict import predict, load_model from segger.models.segger_model import Segger from segger.training.train import LitSegger @@ -9,14 +10,15 @@ from lightning.pytorch.plugins.environments import LightningEnvironment from matplotlib import pyplot as plt import seaborn as sns + # import pandas as pd from segger.data.utils import calculate_gene_celltype_abundance_embedding + # import scanpy as sc import os from lightning import LightningModule - segger_data_dir = Path("data_tidy/pyg_datasets/cosmx_pancreas_degbugged") models_dir = Path("./models/cosmx_pancreas") @@ -50,7 +52,11 @@ heads=4, num_mid_layers=3, ) -model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum") +model = to_hetero( + model, + (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), + aggr="sum", +) batch = dm.train[0] model.forward(batch.x_dict, batch.edge_index_dict) diff --git a/scripts/train_mimmo_batch.py b/scripts/train_mimmo_batch.py index 7fe41e5a..4c711ff1 100644 --- a/scripts/train_mimmo_batch.py +++ b/scripts/train_mimmo_batch.py @@ -1,4 +1,5 @@ from segger.training.segger_data_module import SeggerDataModule + # from segger.prediction.predict import predict, load_model from segger.models.segger_model import Segger from segger.training.train import LitSegger @@ -9,16 +10,21 @@ from lightning.pytorch.plugins.environments import LightningEnvironment from matplotlib import pyplot as plt import seaborn as sns + # import pandas as pd from segger.data.utils import calculate_gene_celltype_abundance_embedding + # import scanpy as sc import os from lightning import LightningModule - -segger_data_dir = Path("data_tidy/pyg_datasets/project24_MNG/output-XETG00423__0042861__mng_07_TMA__20250303__153740") -models_dir = Path("./models/project24_MNG/output-XETG00423__0042861__mng_07_TMA__20250303__153740") +segger_data_dir = Path( + "data_tidy/pyg_datasets/project24_MNG/output-XETG00423__0042861__mng_07_TMA__20250303__153740" +) +models_dir = Path( + "./models/project24_MNG/output-XETG00423__0042861__mng_07_TMA__20250303__153740" +) # Base directory to store Pytorch Lightning models # models_dir = Path('models') @@ -44,14 +50,18 @@ model = Segger( # is_token_based=is_token_based, - num_tx_tokens= num_tx_tokens, + num_tx_tokens=num_tx_tokens, init_emb=8, hidden_channels=32, out_channels=16, heads=4, num_mid_layers=3, ) -model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum") +model = to_hetero( + model, + 
(["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), + aggr="sum", +) batch = dm.train[0] model.forward(batch.x_dict, batch.edge_index_dict) diff --git a/scripts/train_model.py b/scripts/train_model.py index 5ed0fd2b..6294eeed 100644 --- a/scripts/train_model.py +++ b/scripts/train_model.py @@ -8,7 +8,7 @@ from lightning import Trainer parser = argparse.ArgumentParser() -parser.add_argument('--data_dir', type=Path, required=True) +parser.add_argument("--data_dir", type=Path, required=True) args = parser.parse_args() segger_data_dir = args.data_dir @@ -32,7 +32,11 @@ heads=4, num_mid_layers=2, ) -model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum") +model = to_hetero( + model, + (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), + aggr="sum", +) ls = LitSegger(model=model) @@ -46,4 +50,4 @@ logger=CSVLogger(models_dir), ) -trainer.fit(ls, datamodule=dm) \ No newline at end of file +trainer.fit(ls, datamodule=dm) diff --git a/scripts/train_model_sample.py b/scripts/train_model_sample.py index f64b39dc..34cd688c 100644 --- a/scripts/train_model_sample.py +++ b/scripts/train_model_sample.py @@ -1,4 +1,5 @@ from segger.training.segger_data_module import SeggerDataModule + # from segger.prediction.predict import predict, load_model from segger.models.segger_model import Segger from segger.training.train import LitSegger @@ -9,14 +10,15 @@ from lightning.pytorch.plugins.environments import LightningEnvironment from matplotlib import pyplot as plt import seaborn as sns + # import pandas as pd from segger.data.utils import calculate_gene_celltype_abundance_embedding + # import scanpy as sc import os from lightning import LightningModule - segger_data_dir = Path("data_tidy/pyg_datasets/human_CRC_seg_cells") models_dir = Path("./models/human_CRC_seg_cells") @@ -44,14 +46,18 @@ model = Segger( # is_token_based=is_token_based, - num_tx_tokens= num_tx_tokens, + num_tx_tokens=num_tx_tokens, init_emb=8, hidden_channels=64, out_channels=16, heads=4, num_mid_layers=3, ) -model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum") +model = to_hetero( + model, + (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), + aggr="sum", +) batch = dm.train[0] model.forward(batch.x_dict, batch.edge_index_dict) diff --git a/scripts/train_yiheng_5k.py b/scripts/train_yiheng_5k.py index eed698c1..81f513dc 100644 --- a/scripts/train_yiheng_5k.py +++ b/scripts/train_yiheng_5k.py @@ -1,4 +1,5 @@ from segger.training.segger_data_module import SeggerDataModule + # from segger.prediction.predict import predict, load_model from segger.models.segger_model import Segger from segger.training.train import LitSegger @@ -9,16 +10,21 @@ from lightning.pytorch.plugins.environments import LightningEnvironment from matplotlib import pyplot as plt import seaborn as sns + # import pandas as pd from segger.data.utils import calculate_gene_celltype_abundance_embedding + # import scanpy as sc import os from lightning import LightningModule - -segger_data_dir = Path("data_tidy/pyg_datasets/MNG_5k_sampled/output-XETG00078__0041719__Region_2__20241203__142052/") -models_dir = Path("./models/MNG_5k_sampled/output-XETG00078__0041719__Region_2__20241203__142052/") +segger_data_dir = Path( + "data_tidy/pyg_datasets/MNG_5k_sampled/output-XETG00078__0041719__Region_2__20241203__142052/" +) +models_dir = Path( + "./models/MNG_5k_sampled/output-XETG00078__0041719__Region_2__20241203__142052/" +) # Base 
directory to store Pytorch Lightning models # models_dir = Path('models') @@ -44,14 +50,18 @@ model = Segger( # is_token_based=is_token_based, - num_tx_tokens= num_tx_tokens, + num_tx_tokens=num_tx_tokens, init_emb=8, hidden_channels=32, out_channels=16, heads=4, num_mid_layers=3, ) -model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum") +model = to_hetero( + model, + (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), + aggr="sum", +) batch = dm.train[0] model.forward(batch.x_dict, batch.edge_index_dict) diff --git a/src/segger/data/README.md b/src/segger/data/README.md index 28d7df0c..06170114 100644 --- a/src/segger/data/README.md +++ b/src/segger/data/README.md @@ -87,7 +87,6 @@ Where: For each tile, a graph $$G$$ is constructed with: - **Nodes ($$V$$)**: - - **Transcripts**: Represented by their spatial coordinates $$(x_t, y_t)$$ and feature vectors $$\mathbf{f}_t$$. - **Boundaries**: Represented by centroid coordinates $$(x_b, y_b)$$ and associated properties (e.g., area). diff --git a/src/segger/data/parquet/_utils.py b/src/segger/data/parquet/_utils.py index 1a09e9e2..5f04b743 100644 --- a/src/segger/data/parquet/_utils.py +++ b/src/segger/data/parquet/_utils.py @@ -129,8 +129,9 @@ def read_parquet_region( columns = list({x, y} | set(extra_columns)) # Check if 'Geometry', 'geometry', 'polygon', or 'Polygon' is in the columns - if any(col in columns for col in ['Geometry', 'geometry', 'polygon', 'Polygon']): + if any(col in columns for col in ["Geometry", "geometry", "polygon", "Polygon"]): import geopandas as gpd + # If geometry columns are present, read with geopandas region = gpd.read_parquet( filepath, @@ -199,7 +200,7 @@ def get_polygons_from_xy( # Scale polygons around their centroid gs = gpd.GeoSeries( [ - scale(geom, xfact=scale_factor, yfact=scale_factor, origin='centroid') + scale(geom, xfact=scale_factor, yfact=scale_factor, origin="centroid") for geom in gs ], index=gs.index, @@ -357,7 +358,11 @@ def filter_transcripts( mask = pd.Series(True, index=transcripts_df.index) if filter_substrings is not None and label is not None: mask &= ~transcripts_df[label].str.startswith(tuple(filter_substrings)) - if min_qv is not None and qv_column is not None and qv_column in transcripts_df.columns: + if ( + min_qv is not None + and qv_column is not None + and qv_column in transcripts_df.columns + ): mask &= transcripts_df[qv_column].ge(min_qv) return transcripts_df[mask] diff --git a/src/segger/data/parquet/sample.py b/src/segger/data/parquet/sample.py index c0aed561..95f2a243 100644 --- a/src/segger/data/parquet/sample.py +++ b/src/segger/data/parquet/sample.py @@ -18,6 +18,7 @@ import torch import random from segger.data.parquet.transcript_embedding import TranscriptEmbedding + # import re @@ -209,7 +210,9 @@ def transcripts_metadata(self) -> dict: logging.warning(f"Number of missing genes: {len(missing_genes)}") self.settings.transcripts.filter_substrings.extend(missing_genes) # pattern = "|".join(self.settings.transcripts.filter_substrings) - pattern = "|".join(f"^{s}" for s in self.settings.transcripts.filter_substrings) + pattern = "|".join( + f"^{s}" for s in self.settings.transcripts.filter_substrings + ) mask = pc.invert(pc.match_substring_regex(names, pattern)) filtered_names = pc.filter(names, mask).to_pylist() metadata["feature_names"] = [ @@ -1168,11 +1171,13 @@ def get_boundary_props( """ # Get polygons from coordinates # Use getattr to check for the geometry column - geometry_column = 
getattr(self.settings.boundaries, 'geometry', None) + geometry_column = getattr(self.settings.boundaries, "geometry", None) if geometry_column and geometry_column in self.boundaries.columns: polygons = self.boundaries[geometry_column] else: - polygons = self.boundaries['geometry'] # Assign None if the geometry column does not exist + polygons = self.boundaries[ + "geometry" + ] # Assign None if the geometry column does not exist # Geometric properties of polygons props = self.get_polygon_props(polygons) props = torch.as_tensor(props.values).float() @@ -1233,9 +1238,11 @@ def to_pyg_dataset( # Set up Boundary nodes # Check if boundaries have geometries - geometry_column = getattr(self.settings.boundaries, 'geometry', None) + geometry_column = getattr(self.settings.boundaries, "geometry", None) if geometry_column and geometry_column in self.boundaries.columns: - polygons = gpd.GeoSeries(self.boundaries[geometry_column], index=self.boundaries.index) + polygons = gpd.GeoSeries( + self.boundaries[geometry_column], index=self.boundaries.index + ) else: # Fallback: compute polygons polygons = utils.get_polygons_from_xy( diff --git a/src/segger/models/README.md b/src/segger/models/README.md index 033e5456..319e68b6 100644 --- a/src/segger/models/README.md +++ b/src/segger/models/README.md @@ -6,7 +6,6 @@ The `segger` model is a graph neural network designed to handle heterogeneous gr 1. **Input Node Features**: For input node features \( \mathbf{x} \), the model distinguishes between one-dimensional (transcript) nodes and multi-dimensional (boundary or nucleus) nodes by checking the dimension of \( \mathbf{x} \). - - **Transcript Nodes**: If \( \mathbf{x} \) is 1-dimensional (e.g., for tokenized transcript data), the model applies an embedding layer: $$ @@ -14,7 +13,6 @@ The `segger` model is a graph neural network designed to handle heterogeneous gr $$ where \( i \) is the transcript token index. - - **Nuclei or Cell Boundary Nodes**: If \( \mathbf{x} \) has multiple dimensions, the model applies a linear transformation: $$ @@ -31,13 +29,11 @@ The `segger` model is a graph neural network designed to handle heterogeneous gr $$ where: - - \( \alpha\_{ij} \) is the attention coefficient between node \( i \) and node \( j \), computed as: $$ \alpha_{ij} = \frac{\exp\left( \text{LeakyReLU}\left( \mathbf{a}^{\top} [\mathbf{W}^{(l)} \mathbf{h}_{i}^{(l)} || \mathbf{W}^{(l)} \mathbf{h}_{j}^{(l)}] \right)\right)}{\sum_{k \in \mathcal{N}(i)} \exp\left( \text{LeakyReLU}\left( \mathbf{a}^{\top} [\mathbf{W}^{(l)} \mathbf{h}_{i}^{(l)} || \mathbf{W}^{(l)} \mathbf{h}_{k}^{(l)}] \right)\right)} $$ - - \( \mathbf{a} \) is a learnable attention vector. 3. 
**Residual Linear Connections**: diff --git a/src/segger/prediction/boundary.py b/src/segger/prediction/boundary.py index 8a3cc0d3..dbd3e72d 100644 --- a/src/segger/prediction/boundary.py +++ b/src/segger/prediction/boundary.py @@ -223,7 +223,8 @@ def calculate_part_2(self, plot=True): simplex_id = list(edges[current_edge]["simplices"].keys())[0] simplex = d.simplices[simplex_id] if ( - edges[current_edge]["length"] > 1.5 * d_max and edges[current_edge]["simplices"][simplex_id] > 90 + edges[current_edge]["length"] > 1.5 * d_max + and edges[current_edge]["simplices"][simplex_id] > 90 ) or edges[current_edge]["simplices"][simplex_id] > 180 - 180 / 16: # delete edge and the simplex start @@ -259,7 +260,9 @@ def find_cycles(self): if len(cycles) == 1: geom = Polygon(self.d.points[cycles[0]]) else: - geom = MultiPolygon([Polygon(self.d.points[c]) for c in cycles if len(c) >= 3]) + geom = MultiPolygon( + [Polygon(self.d.points[c]) for c in cycles if len(c) >= 3] + ) except Exception as e: print(e, cycles) return None @@ -329,7 +332,13 @@ def generate_boundaries(df, x="x_location", y="y_location", cell_id="segger_cell res = [] group_df = df.groupby(cell_id) for cell_id, t in tqdm(group_df, total=group_df.ngroups): - res.append({"cell_id": cell_id, "length": len(t), "geom": generate_boundary(t, x=x, y=y)}) + res.append( + { + "cell_id": cell_id, + "length": len(t), + "geom": generate_boundary(t, x=x, y=y), + } + ) return gpd.GeoDataFrame( data=[[b["cell_id"], b["length"]] for b in res], @@ -348,7 +357,9 @@ def generate_boundary(t, x="x_location", y="y_location"): bi.calculate_part_2(plot=False) geom = bi.find_cycles() except Exception as e: - print(f"Failed to generate a boundary for the set of points of size len(t)={len(t)}") + print( + f"Failed to generate a boundary for the set of points of size len(t)={len(t)}" + ) print("Warning:") print(e) print("Skipping this set") @@ -356,8 +367,6 @@ def generate_boundary(t, x="x_location", y="y_location"): return geom - - def extract_largest_polygon(geom): if isinstance(geom, MultiPolygon): return max(geom.geoms, key=lambda p: p.area) # Keep the largest polygon diff --git a/src/segger/prediction/predict_new.py b/src/segger/prediction/predict_new.py index 8e49298a..0a67115b 100644 --- a/src/segger/prediction/predict_new.py +++ b/src/segger/prediction/predict_new.py @@ -188,8 +188,6 @@ def sort_order(c): return lit_segger - - def get_similarity_scores( model: torch.nn.Module, batch: Batch, @@ -242,7 +240,8 @@ def get_normalized_embedding(x: torch.Tensor, key: str) -> torch.Tensor: if is_1d: x = x.unsqueeze(1) emb = ( - model.tx_embedding[key]((x.sum(-1).int())) if is_1d + model.tx_embedding[key]((x.sum(-1).int())) + if is_1d else model.lin0[key](x.float()) ) return F.normalize(emb, p=2, dim=1) @@ -252,8 +251,7 @@ def get_normalized_embedding(x: torch.Tensor, key: str) -> torch.Tensor: embeddings = model(batch.x_dict, batch.edge_index_dict) else: embeddings = { - key: get_normalized_embedding(x, key) - for key, x in batch.x_dict.items() + key: get_normalized_embedding(x, key) for key, x in batch.x_dict.items() } def sparse_multiply( @@ -264,14 +262,20 @@ def sparse_multiply( """ Compute sparse similarity scores using torch only (no cupy). 
""" - padded_emb = F.pad(embeddings[to_type], (0, 0, 0, 1)) # Add dummy row for -1 padding + padded_emb = F.pad( + embeddings[to_type], (0, 0, 0, 1) + ) # Add dummy row for -1 padding neighbor_embs = padded_emb[edge_index] # [num_from, k, dim] source_embs = embeddings[from_type].unsqueeze(1) # [num_from, 1, dim] similarity = (neighbor_embs * source_embs).sum(dim=-1) # [num_from, k] valid_mask = edge_index != -1 - row_idx = torch.arange(edge_index.size(0), device=device).unsqueeze(1).expand_as(edge_index) + row_idx = ( + torch.arange(edge_index.size(0), device=device) + .unsqueeze(1) + .expand_as(edge_index) + ) row_valid = row_idx[valid_mask] col_valid = edge_index[valid_mask] val_valid = similarity[valid_mask] @@ -280,7 +284,9 @@ def sparse_multiply( val_valid = torch.sigmoid(val_valid) indices = torch.stack([row_valid, col_valid], dim=0) - return torch.sparse_coo_tensor(indices, val_valid, shape, device=device).coalesce() + return torch.sparse_coo_tensor( + indices, val_valid, shape, device=device + ).coalesce() return sparse_multiply(embeddings, dense_index, shape) @@ -383,10 +389,12 @@ def _get_id(): sub_v = sub_v[valid] if sub_i.numel() > 0: - edge_df = pd.DataFrame({ - "source": [transcript_id[i.item()] for i in sub_i[0]], - "target": [transcript_id[i.item()] for i in sub_i[1]], - }) + edge_df = pd.DataFrame( + { + "source": [transcript_id[i.item()] for i in sub_i[0]], + "target": [transcript_id[i.item()] for i in sub_i[1]], + } + ) edge_index_ddf = delayed(dd.from_pandas)(edge_df, npartitions=1) delayed_write_edge_index = delayed(edge_index_ddf.to_parquet)( @@ -598,7 +606,9 @@ def segment( print(f"Mapping component labels...") def _get_id(): - return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx" + return ( + "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx" + ) new_ids = np.array([_get_id() for _ in range(n)]) comp_labels = new_ids[comps] @@ -608,7 +618,9 @@ def _get_id(): unassigned_transcripts_df = transcripts_df_filtered.loc[ unassigned_mask, ["transcript_id"] ] - new_segger_cell_ids = unassigned_transcripts_df["transcript_id"].map(comp_labels) + new_segger_cell_ids = unassigned_transcripts_df["transcript_id"].map( + comp_labels + ) unassigned_transcripts_df = unassigned_transcripts_df.assign( segger_cell_id=new_segger_cell_ids ) diff --git a/src/segger/prediction/predict_parquet.py b/src/segger/prediction/predict_parquet.py index 8214ea41..a7c97855 100644 --- a/src/segger/prediction/predict_parquet.py +++ b/src/segger/prediction/predict_parquet.py @@ -255,7 +255,8 @@ def get_normalized_embedding(x, key): if is_1d: x = x.unsqueeze(1) embed = ( - model.tx_embedding[key]((x.sum(-1).int())) if is_1d + model.tx_embedding[key]((x.sum(-1).int())) + if is_1d else model.lin0[key](x.float()) ) embed = F.normalize(embed, p=2, dim=1) @@ -266,7 +267,8 @@ def get_normalized_embedding(x, key): embeddings = model(batch.x_dict, batch.edge_index_dict) else: # to go with the inital embeddings for tx-tx embeddings = { - key: get_normalized_embedding(x, key) for key, x in batch.x_dict.items() + key: get_normalized_embedding(x, key) + for key, x in batch.x_dict.items() } def sparse_multiply(embeddings, edge_index, shape) -> coo_matrix: diff --git a/src/segger/training/train.py b/src/segger/training/train.py index e0c5de03..6a4163e4 100644 --- a/src/segger/training/train.py +++ b/src/segger/training/train.py @@ -147,10 +147,9 @@ def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor: # optimizer = torch.optim.Adam(self.parameters(), 
lr=1e-3) # return optimizer - def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-4) return optimizer def on_before_optimizer_step(self, optimizer): - torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0) \ No newline at end of file + torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0) diff --git a/src/segger/validation/xenium_explorer.py b/src/segger/validation/xenium_explorer.py index 220f9f0b..c118e333 100644 --- a/src/segger/validation/xenium_explorer.py +++ b/src/segger/validation/xenium_explorer.py @@ -9,6 +9,7 @@ import matplotlib.pyplot as plt from tqdm import tqdm from typing import Dict, Any, Optional, List, Tuple + # from segger.prediction.boundary import generate_boundary from zarr.storage import ZipStore import zarr @@ -17,17 +18,21 @@ def generate_boundary(seg_cell): """Generate convex hull boundary for a cell""" # Your existing implementation - points = seg_cell[['x_location', 'y_location']].values + points = seg_cell[["x_location", "y_location"]].values if len(points) < 3: return None try: from scipy.spatial import ConvexHull + hull = ConvexHull(points) return Polygon(points[hull.vertices]) except: return None -def get_flatten_version(polygon_vertices: List[List[Tuple[float, float]]], max_value: int = 21) -> np.ndarray: + +def get_flatten_version( + polygon_vertices: List[List[Tuple[float, float]]], max_value: int = 21 +) -> np.ndarray: """Standardize list of polygon vertices to a fixed shape. Args: @@ -43,7 +48,7 @@ def get_flatten_version(polygon_vertices: List[List[Tuple[float, float]]], max_v pass if isinstance(vertices, np.ndarray): vertices = vertices.tolist() - + if len(vertices) > max_value: flattened.append(vertices[:max_value]) else: @@ -112,7 +117,9 @@ def seg2explorer( nucleus_convex_hull = None if len(seg_nucleous) >= 3: try: - nucleus_convex_hull = ConvexHull(seg_nucleous[["x_location", "y_location"]]) + nucleus_convex_hull = ConvexHull( + seg_nucleous[["x_location", "y_location"]] + ) except Exception: pass @@ -137,7 +144,8 @@ def seg2explorer( seg_nucleous[["x_location", "y_location"]].values[ nucleus_convex_hull.vertices ] - if nucleus_convex_hull else np.array([[], []]).T + if nucleus_convex_hull + else np.array([[], []]).T ) seg_mask_value.append(uint_cell_id) @@ -162,7 +170,9 @@ def seg2explorer( "seg_mask_value": np.array(seg_mask_value, dtype=np.int32), } - source_zarr_store = ZipStore(source_path / "cells.zarr.zip", mode="r") # added this line + source_zarr_store = ZipStore( + source_path / "cells.zarr.zip", mode="r" + ) # added this line existing_store = zarr.open(source_zarr_store, mode="r") new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") @@ -175,7 +185,9 @@ def seg2explorer( # Reshape cell polygons to (n_cells, 50) format n_cells = cell_polygon_vertices.shape[0] - cell_vertices_flat = cell_polygon_vertices.reshape(n_cells, -1)[:, :257] # Take first 50 values + cell_vertices_flat = cell_polygon_vertices.reshape(n_cells, -1)[ + :, :257 + ] # Take first 50 values set1 = polygon_group.create_group("1") set1["cell_index"] = np.arange(1, n_cells + 1, dtype=np.uint32) # 1-based indexing @@ -213,7 +225,9 @@ def seg2explorer( new_zarr.create_group("/cell_groups") for i, cluster in enumerate(clusters_names): new_zarr["cell_groups"].create_group(str(i)) - group_values = [clusters_dict[cluster].get(x, 0) for x in clustering_df[cluster]] + group_values = [ + clusters_dict[cluster].get(x, 0) for x in clustering_df[cluster] + ] indices, indptr = 
get_indices_indptr(np.array(group_values)) new_zarr["cell_groups"][str(i)]["indices"] = indices new_zarr["cell_groups"][str(i)]["indptr"] = indptr @@ -240,7 +254,6 @@ def seg2explorer( ) - def str_to_uint32(cell_id_str: str) -> Tuple[int, int]: """Convert a string cell ID back to uint32 format. @@ -484,11 +497,12 @@ def generate_experiment_file( json.dump(experiment, f, indent=2) - - -from pqdm.processes import pqdm # or from pqdm.processes import pqdm for process backend +from pqdm.processes import ( + pqdm, +) # or from pqdm.processes import pqdm for process backend import os + def _process_one_cell(args): seg_cell_id, seg_cell, area_low, area_high = args @@ -506,21 +520,21 @@ def _process_one_cell(args): cell_vertices = list(cell_convex_hull.exterior.coords) if cell_vertices[0] == cell_vertices[-1]: cell_vertices = cell_vertices[:-1] - + n_vertices = len(cell_vertices) - + # Sample up to 16 vertices if n_vertices > 16: # Evenly sample 16 vertices from original set - indices = np.linspace(0, n_vertices-1, 16, dtype=int) + indices = np.linspace(0, n_vertices - 1, 16, dtype=int) sampled_vertices = [cell_vertices[i] for i in indices] else: sampled_vertices = cell_vertices - + # Pad with first vertex if needed if len(sampled_vertices) < 16: sampled_vertices += [sampled_vertices[0]] * (16 - len(sampled_vertices)) - + return { "seg_cell_id": seg_cell_id, "cell_area": float(cell_convex_hull.area), @@ -541,7 +555,7 @@ def seg2explorer_pqdm( cell_id_columns: str = "seg_cell_id", area_low: float = 10, area_high: float = 100, - n_jobs: int = 1 + n_jobs: int = 1, ) -> None: source_path = Path(source_path) storage = Path(output_dir) @@ -551,11 +565,20 @@ def seg2explorer_pqdm( # Build a lightweight iterable of work items (id, slice, thresholds) # NOTE: this will still materialize each group slice, but we avoid copying the whole DF per worker. - work_iter = ((seg_cell_id, seg_cell, area_low, area_high) for seg_cell_id, seg_cell in grouped_by) + work_iter = ( + (seg_cell_id, seg_cell, area_low, area_high) + for seg_cell_id, seg_cell in grouped_by + ) # Parallel map with threads (good default). Tune n_jobs. 
# n_jobs = min(32, os.cpu_count() or 8) - results = pqdm(work_iter, _process_one_cell, n_jobs=n_jobs, desc="Cells", exception_behaviour="immediate") + results = pqdm( + work_iter, + _process_one_cell, + n_jobs=n_jobs, + desc="Cells", + exception_behaviour="immediate", + ) # Collate results cell_id2old_id: Dict[int, Any] = {} @@ -575,7 +598,9 @@ def seg2explorer_pqdm( # Flatten vertices exactly as before cell_polygon_vertices = get_flatten_version(polygon_vertices) - source_zarr_store = ZipStore(source_path / "cells.zarr.zip", mode="r") # added this line + source_zarr_store = ZipStore( + source_path / "cells.zarr.zip", mode="r" + ) # added this line existing_store = zarr.open(source_zarr_store, mode="r") new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") @@ -584,11 +609,13 @@ def seg2explorer_pqdm( # Process cell polygons (set 1) # cell_polygons = cells["polygon_vertices"][1] # Cell polygons are at index 1 - cell_num_vertices = polygon_num_vertices # Cell vertex counts + cell_num_vertices = polygon_num_vertices # Cell vertex counts # Reshape cell polygons to (n_cells, 50) format n_cells = cell_polygon_vertices.shape[0] - cell_vertices_flat = cell_polygon_vertices.reshape(n_cells, -1)[:, :33] # Take first 50 values + cell_vertices_flat = cell_polygon_vertices.reshape(n_cells, -1)[ + :, :33 + ] # Take first 50 values set1 = polygon_group.create_group("1") set1["cell_index"] = np.arange(1, n_cells + 1, dtype=np.uint32) # 1-based indexing @@ -626,7 +653,9 @@ def seg2explorer_pqdm( new_zarr.create_group("/cell_groups") for i, cluster in enumerate(clusters_names): new_zarr["cell_groups"].create_group(str(i)) - group_values = [clusters_dict[cluster].get(x, 0) for x in clustering_df[cluster]] + group_values = [ + clusters_dict[cluster].get(x, 0) for x in clustering_df[cluster] + ] indices, indptr = get_indices_indptr(np.array(group_values)) new_zarr["cell_groups"][str(i)]["indices"] = indices new_zarr["cell_groups"][str(i)]["indptr"] = indptr
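The same model-construction block recurs (reformatted, with no functional change) across the training scripts touched above. Reassembled in one place, it reads roughly as follows; this is a minimal sketch, and the num_tx_tokens value and the "from torch_geometric.nn import to_hetero" import are assumptions inferred from the surrounding scripts rather than lines shown in this diff.

# Minimal sketch of the heterogeneous Segger model setup used in scripts/train_*.py.
# Assumptions: to_hetero comes from torch_geometric.nn, and num_tx_tokens (here 500)
# is a placeholder -- the real scripts derive it from their own data module.
from segger.models.segger_model import Segger
from torch_geometric.nn import to_hetero

num_tx_tokens = 500  # placeholder; set from the dataset in the actual scripts

model = Segger(
    num_tx_tokens=num_tx_tokens,
    init_emb=8,
    hidden_channels=32,
    out_channels=16,
    heads=4,
    num_mid_layers=3,
)

# Lift the homogeneous model to the heterogeneous graph with transcript ("tx") and
# boundary ("bd") node types and the two edge types used throughout these scripts.
model = to_hetero(
    model,
    (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]),
    aggr="sum",
)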