
Commit b76b4aa

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8a778a3 commit b76b4aa

5 files changed: +56 −72 lines


scripts/create_data_fast_sample.py

Lines changed: 30 additions & 36 deletions
@@ -7,16 +7,13 @@
 import numpy as np
 from segger.data.parquet._utils import get_polygons_from_xy

-xenium_data_dir = Path('data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/')
-segger_data_dir = Path('data_tidy/pyg_datasets/bc_rep1_emb_200_final')
+xenium_data_dir = Path("data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/")
+segger_data_dir = Path("data_tidy/pyg_datasets/bc_rep1_emb_200_final")


-scrnaseq_file = Path('/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad')
-celltype_column = 'celltype_minor'
-gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(
-    sc.read(scrnaseq_file),
-    celltype_column
-)
+scrnaseq_file = Path("/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad")
+celltype_column = "celltype_minor"
+gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(sc.read(scrnaseq_file), celltype_column)

 sample = STSampleParquet(
     base_dir=xenium_data_dir,

@@ -43,30 +40,29 @@ (whitespace-only changes: the keyword arguments of sample.save are re-indented)


 sample.save(
     data_dir=segger_data_dir,
     k_bd=3,
     dist_bd=15,
     k_tx=3,
     dist_tx=5,
     tile_width=200,
     tile_height=200,
     neg_sampling_ratio=5.0,
     frac=1.0,
     val_prob=0.3,
     test_prob=0,
 )


-xenium_data_dir = Path('data_tidy/bc_5k')
-segger_data_dir = Path('data_tidy/pyg_datasets/bc_5k_emb_new')
-
+xenium_data_dir = Path("data_tidy/bc_5k")
+segger_data_dir = Path("data_tidy/pyg_datasets/bc_5k_emb_new")


 sample = STSampleParquet(
     base_dir=xenium_data_dir,
     n_workers=8,
-    sample_type='xenium',
-    weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available
+    sample_type="xenium",
+    weights=gene_celltype_abundance_embedding,  # uncomment if gene-celltype embeddings are available
 )

@@ -88,16 +84,14 @@ (whitespace-only re-indentation; two trailing blank lines removed at end of file)


 sample.save(
     data_dir=segger_data_dir,
     k_bd=3,
     dist_bd=15.0,
     k_tx=15,
     dist_tx=3,
     tile_size=50_000,
     neg_sampling_ratio=5.0,
     frac=0.1,
     val_prob=0.1,
     test_prob=0.1,
 )
-
-

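For readers unfamiliar with the embedding used in this script: calculate_gene_celltype_abundance_embedding condenses a scRNA-seq reference into a gene-by-celltype matrix that is later passed to STSampleParquet as transcript node weights. As a rough illustration only (not segger's actual implementation), one common way to build such a matrix is the fraction of cells of each type in which a gene is detected; the helper name and file path below are hypothetical.

# Hypothetical sketch: per-gene detection fraction within each cell type.
# Assumes an AnnData with counts in .X and labels in .obs[celltype_column];
# this is NOT segger's implementation, just an illustration of the idea.
import numpy as np
import pandas as pd
import scanpy as sc

def abundance_embedding(adata, celltype_column: str) -> pd.DataFrame:
    labels = adata.obs[celltype_column].astype("category")
    columns = {}
    for ct in labels.cat.categories:
        subset = adata[(labels == ct).values]
        # Fraction of cells of this type in which each gene is detected
        columns[ct] = np.asarray((subset.X > 0).mean(axis=0)).ravel()
    return pd.DataFrame(columns, index=adata.var_names)  # genes x cell types

adata = sc.read("reference_atlas.h5ad")  # placeholder path
embedding = abundance_embedding(adata, "celltype_minor")
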
scripts/predict_model_sample.py

Lines changed: 2 additions & 2 deletions
@@ -22,8 +22,8 @@
 seg_tag = "bc_fast_data_emb_major"
 model_version = 1

-segger_data_dir = Path('data_tidy/pyg_datasets') / seg_tag
-models_dir = Path("./models") / seg_tag
+segger_data_dir = Path("data_tidy/pyg_datasets") / seg_tag
+models_dir = Path("./models") / seg_tag
 benchmarks_dir = Path("/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc")
 transcripts_file = "data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet"
 # Initialize the Lightning data module

scripts/train_model_sample.py

Lines changed: 12 additions & 15 deletions
@@ -15,7 +15,7 @@
 import os


-segger_data_dir = segger_data_dir = Path('data_tidy/pyg_datasets/bc_rep1_emb_final_200')
+segger_data_dir = segger_data_dir = Path("data_tidy/pyg_datasets/bc_rep1_emb_final_200")
 models_dir = Path("./models/bc_rep1_emb_final_200")

 # Base directory to store Pytorch Lightning models

@@ -35,37 +35,34 @@

 # If you use custom gene embeddings, use the following two lines instead:
 is_token_based = False
-num_tx_tokens = dm.train[0].x_dict["tx"].shape[1] # Set the number of tokens to the number of genes
+num_tx_tokens = dm.train[0].x_dict["tx"].shape[1]  # Set the number of tokens to the number of genes


 num_bd_features = dm.train[0].x_dict["bd"].shape[1]

 # Initialize the Lightning model
 ls = LitSegger(
-    is_token_based = is_token_based,
-    num_node_features = {"tx": num_tx_tokens, "bd": num_bd_features},
-    init_emb=8,
+    is_token_based=is_token_based,
+    num_node_features={"tx": num_tx_tokens, "bd": num_bd_features},
+    init_emb=8,
     hidden_channels=64,
     out_channels=16,
     heads=4,
     num_mid_layers=3,
-    aggr='sum',
-    learning_rate=1e-3
+    aggr="sum",
+    learning_rate=1e-3,
 )

 # Initialize the Lightning trainer
 trainer = Trainer(
-    accelerator='cuda',
-    strategy='auto',
-    precision='16-mixed',
-    devices=2, # set higher number if more gpus are available
+    accelerator="cuda",
+    strategy="auto",
+    precision="16-mixed",
+    devices=2,  # set higher number if more gpus are available
     max_epochs=400,
     default_root_dir=models_dir,
     logger=CSVLogger(models_dir),
 )


-trainer.fit(
-    model=ls,
-    datamodule=dm
-)
+trainer.fit(model=ls, datamodule=dm)

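A side note on the num_tx_tokens / num_bd_features lines in this diff: dm.train[0] is a PyTorch Geometric heterogeneous graph, whose x_dict maps each node type to its feature matrix, so .shape[1] reads the per-node feature width. A minimal, self-contained sketch of the same access pattern with toy sizes (not segger's data module):

# Toy example of reading feature widths from a PyG heterogeneous graph,
# mirroring dm.train[0].x_dict["tx"].shape[1] in the script above.
import torch
from torch_geometric.data import HeteroData

tile = HeteroData()
tile["tx"].x = torch.randn(500, 128)  # 500 transcript nodes, 128-dim features
tile["bd"].x = torch.randn(40, 7)  # 40 boundary nodes, 7-dim features

num_tx_tokens = tile.x_dict["tx"].shape[1]  # 128
num_bd_features = tile.x_dict["bd"].shape[1]  # 7
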
src/segger/data/utils.py

Lines changed: 5 additions & 5 deletions
@@ -43,7 +43,7 @@ def try_import(module_name):
 from datetime import timedelta


-def filter_transcripts( #ONLY FOR XENIUM
+def filter_transcripts(  # ONLY FOR XENIUM
     transcripts_df: pd.DataFrame,
     min_qv: float = 20.0,
 ) -> pd.DataFrame:

@@ -65,14 +65,14 @@ def filter_transcripts( #ONLY FOR XENIUM
         "DeprecatedCodeword_",
         "UnassignedCodeword_",
     )
-
-    transcripts_df['feature_name'] = transcripts_df['feature_name'].apply(
+
+    transcripts_df["feature_name"] = transcripts_df["feature_name"].apply(
         lambda x: x.decode("utf-8") if isinstance(x, bytes) else x
     )
-    mask_quality = transcripts_df['qv'] >= min_qv
+    mask_quality = transcripts_df["qv"] >= min_qv

     # Apply the filter for unwanted codewords using Dask string functions
-    mask_codewords = ~transcripts_df['feature_name'].str.startswith(filter_codewords)
+    mask_codewords = ~transcripts_df["feature_name"].str.startswith(filter_codewords)

     # Combine the filters and return the filtered Dask DataFrame
     mask = mask_quality & mask_codewords

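In effect, filter_transcripts keeps transcripts with quality value qv >= min_qv whose feature_name does not start with a control-codeword prefix. A self-contained pandas sketch of that logic on toy rows (illustrative data; only the two prefixes visible in the hunk above are used here — the real filter_codewords tuple starts earlier in the file):

import pandas as pd

# Subset of prefixes shown in the diff; the full tuple in utils.py is longer.
filter_codewords = ("DeprecatedCodeword_", "UnassignedCodeword_")

df = pd.DataFrame({
    "feature_name": [b"ACTB", "UnassignedCodeword_42", "EPCAM"],  # toy rows
    "qv": [35.0, 40.0, 19.9],
})

# Decode byte strings, as the patched function does
df["feature_name"] = df["feature_name"].apply(lambda x: x.decode("utf-8") if isinstance(x, bytes) else x)

mask_quality = df["qv"] >= 20.0
mask_codewords = ~df["feature_name"].str.startswith(filter_codewords)
filtered = df[mask_quality & mask_codewords]  # keeps only the ACTB row
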
src/segger/prediction/predict_parquet.py

Lines changed: 7 additions & 14 deletions
@@ -13,13 +13,7 @@
 from pathlib import Path
 from torch_geometric.loader import DataLoader
 from torch_geometric.data import Batch
-from segger.data.utils import (
-    get_edge_index,
-    format_time,
-    create_anndata,
-    coo_to_dense_adj,
-    filter_transcripts
-)
+from segger.data.utils import get_edge_index, format_time, create_anndata, coo_to_dense_adj, filter_transcripts
 from segger.training.train import LitSegger
 from segger.training.segger_data_module import SeggerDataModule
 from segger.prediction.boundary import generate_boundaries

@@ -544,13 +538,12 @@ def segment(
     if verbose:
         print(f"Applying max score selection logic...")
     output_ddf_save_path = save_dir / "transcripts_df.parquet"
-
-
+
     seg_final_dd = pd.read_parquet(output_ddf_save_path)
-
-    seg_final_filtered = seg_final_dd.sort_values(
-        "score", ascending=False
-    ).drop_duplicates(subset="transcript_id", keep="first")
+
+    seg_final_filtered = seg_final_dd.sort_values("score", ascending=False).drop_duplicates(
+        subset="transcript_id", keep="first"
+    )

     if verbose:
         elapsed_time = time() - step_start_time

@@ -570,7 +563,7 @@ def segment(

     # Outer merge to include all transcripts, even those without assigned cell ids
     transcripts_df_filtered = transcripts_df.merge(seg_final_filtered, on="transcript_id", how="outer")
-
+
     if verbose:
         elapsed_time = time() - step_start_time
         print(f"Merged segmentation results with transcripts in {elapsed_time:.2f} seconds.")

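The "max score selection" pattern in the hunk above deduplicates competing cell assignments: sorting by score descending and dropping duplicate transcript_ids keeps the best-scoring assignment per transcript, and the subsequent outer merge retains transcripts that never received one. A toy pandas illustration (column names other than transcript_id and score are hypothetical):

import pandas as pd

assignments = pd.DataFrame({
    "transcript_id": [1, 1, 2, 3, 3],
    "cell_id": ["a", "b", "a", "b", "c"],  # hypothetical column
    "score": [0.9, 0.4, 0.7, 0.2, 0.8],
})

# Keep the highest-scoring assignment per transcript
best = assignments.sort_values("score", ascending=False).drop_duplicates(
    subset="transcript_id", keep="first"
)
# -> transcript 1: cell "a" (0.9), transcript 2: "a" (0.7), transcript 3: "c" (0.8)

# Outer merge keeps unassigned transcripts too (cell_id becomes NaN)
transcripts = pd.DataFrame({"transcript_id": [1, 2, 3, 4]})
merged = transcripts.merge(best, on="transcript_id", how="outer")
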
0 commit comments