Skip to content

Commit ca85a33

Browse files
Comseg integration (#45)
1 parent ecb67b5 commit ca85a33

File tree

2 files changed

+214
-0
lines changed

2 files changed

+214
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
name: comseg
2+
label: "ComSeg Segmentation"
3+
summary: "Spatial segmentation using ComSeg method"
4+
description: |
5+
ComSeg is a spatial transcriptomics segmentation method that uses transcript locations
6+
and morphological information to define cell boundaries. It is particularly effective
7+
for high-resolution spatial transcriptomics data.
8+
links:
9+
documentation: "https://github.com/openproblems-bio/task_ist_preprocessing"
10+
repository: "https://github.com/openproblems-bio/task_ist_preprocessing"
11+
references:
12+
doi: "10.1038/s41592-020-01018-x"
13+
14+
__merge__: /src/api/comp_method_segmentation.yaml
15+
16+
arguments:
17+
- name: --transcripts_key
18+
type: string
19+
default: "transcripts"
20+
description: "Key for transcripts in the points layer"
21+
- name: --shapes_key
22+
type: string
23+
default: "cell_boundaries"
24+
description: "Key for cell boundaries in the shapes layer"
25+
- name: --images_key
26+
type: string
27+
default: "morphology_mip"
28+
description: "Key for morphology image in the images layer"
29+
- name: --patch_width
30+
type: integer
31+
default: 1200
32+
description: "Width of image patches for processing"
33+
- name: --patch_overlap
34+
type: integer
35+
default: 50
36+
description: "Overlap between patches"
37+
- name: --transcript_patch_width
38+
type: integer
39+
default: 200
40+
description: "Width of transcript patches"
41+
- name: --mean_cell_diameter
42+
type: double
43+
default: 15.0
44+
description: "Expected mean cell diameter in micrometers"
45+
- name: --max_cell_radius
46+
type: double
47+
default: 25.0
48+
description: "Maximum cell radius in micrometers"
49+
- name: --alpha
50+
type: double
51+
default: 0.5
52+
description: "Alpha parameter for ComSeg algorithm"
53+
- name: --min_rna_per_cell
54+
type: integer
55+
default: 5
56+
description: "Minimum number of transcripts per cell"
57+
- name: --gene_column
58+
type: string
59+
default: "feature_name"
60+
description: "Column name for gene identifiers in transcripts data"
61+
- name: --norm_vector
62+
type: boolean
63+
default: false
64+
description: "Whether to normalize vectors in ComSeg"
65+
- name: --allow_disconnected_polygon
66+
type: boolean
67+
default: true
68+
description: "Allow disconnected polygons in segmentation"
69+
70+
71+
resources:
72+
- type: python_script
73+
path: script.py
74+
75+
76+
engines:
77+
- type: docker
78+
image: openproblems/base_python:1
79+
setup:
80+
- type: python
81+
pypi:
82+
- spatialdata
83+
- sopa
84+
- anndata
85+
- pandas
86+
- numpy
87+
- xarray
88+
- scikit-image
89+
- comseg
90+
- scipy
91+
92+
- type: native
93+
94+
runners:
95+
- type: executable
96+
- type: nextflow
97+
directives:
98+
label: [ hightime, midcpu, highmem ]
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import spatialdata as sd
2+
import sopa
3+
import anndata as ad
4+
import pandas as pd
5+
import numpy as np
6+
from scipy import sparse
7+
8+
## VIASH START
9+
par = {
10+
"input": "resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr",
11+
"output": "transcripts.zarr",
12+
13+
"transcripts_key": "transcripts",
14+
"shapes_key": "cell_boundaries",
15+
"images_key": "morphology_mip",
16+
"patch_width": 1200,
17+
"patch_overlap": 50,
18+
"transcript_patch_width": 200,
19+
"mean_cell_diameter": 15.0,
20+
"max_cell_radius": 25.0,
21+
"alpha": 0.5,
22+
"min_rna_per_cell": 5,
23+
"gene_column": "feature_name",
24+
"norm_vector": False,
25+
"allow_disconnected_polygon": True,
26+
}
27+
## VIASH END
28+
29+
def fixed_count_transcripts_aligned(geo_df, points, value_key):
30+
"""
31+
The same function as sopa.aggregation.transcripts._count_transcripts_aligned.
32+
Minor change just the matrix X is converted to csr_matrix, to avoid bug error in comseg call
33+
34+
"""
35+
from scipy.sparse import csr_matrix
36+
from anndata import AnnData
37+
from dask.diagnostics import ProgressBar
38+
from functools import partial
39+
from sopa._settings import settings
40+
import geopandas as gpd
41+
def _add_csr(X_partitions, geo_df, partition, gene_column, gene_names ):
42+
if settings.gene_exclude_pattern is not None:
43+
partition = partition[~partition[gene_column].str.match(settings.gene_exclude_pattern, case=False, na=False)]
44+
45+
points_gdf = gpd.GeoDataFrame(partition, geometry=gpd.points_from_xy(partition["x"], partition["y"]))
46+
joined = geo_df.sjoin(points_gdf)
47+
cells_indices, column_indices = joined.index, joined[gene_column].cat.codes
48+
cells_indices = cells_indices[column_indices >= 0]
49+
column_indices = column_indices[column_indices >= 0]
50+
X_partition = csr_matrix((np.full(len(cells_indices), 1), (cells_indices, column_indices)),
51+
shape=(len(geo_df), len(gene_names)),
52+
)
53+
X_partitions.append(X_partition)
54+
55+
56+
points[value_key] = points[value_key].astype("category").cat.as_known()
57+
gene_names = points[value_key].cat.categories.astype(str)
58+
X = csr_matrix((len(geo_df), len(gene_names)), dtype=int)
59+
adata = AnnData(X=X, var=pd.DataFrame(index=gene_names))
60+
adata.obs_names = geo_df.index.astype(str)
61+
geo_df = geo_df.reset_index()
62+
X_partitions = []
63+
with ProgressBar():
64+
points.map_partitions(
65+
partial(_add_csr, X_partitions, geo_df, gene_column=value_key, gene_names=gene_names),
66+
meta=(),
67+
).compute()
68+
for X_partition in X_partitions:
69+
adata.X += X_partition
70+
if settings.gene_exclude_pattern is not None:
71+
adata = adata[:, ~adata.var_names.str.match(settings.gene_exclude_pattern, case=False, na=False)].copy()
72+
return adata
73+
74+
75+
# Read input SpatialData
76+
sdata = sd.read_zarr(par["input"])
77+
sopa.make_image_patches(sdata, patch_width=par["patch_width"], patch_overlap=par["patch_overlap"])
78+
79+
transcript_patch_args = {
80+
"sdata": sdata,
81+
"write_cells_centroids": True,
82+
"patch_width": par["transcript_patch_width"],
83+
}
84+
transcript_patch_args["prior_shapes_key"] = par["shapes_key"]
85+
86+
sopa.make_transcript_patches(**transcript_patch_args)
87+
88+
config = {
89+
"dict_scale": {"x": 1, "y": 1, "z": 1},
90+
"mean_cell_diameter": par["mean_cell_diameter"],
91+
"max_cell_radius": par["max_cell_radius"],
92+
"norm_vector": par["norm_vector"],
93+
"alpha": par["alpha"],
94+
"allow_disconnected_polygon": par["allow_disconnected_polygon"],
95+
"min_rna_per_cell": par["min_rna_per_cell"],
96+
"gene_column": par["gene_column"],
97+
}
98+
99+
100+
sopa.aggregation.transcripts._count_transcripts_aligned = fixed_count_transcripts_aligned
101+
sopa.segmentation.comseg(sdata, config)
102+
103+
# Create output SpatialData
104+
sd_output = sd.SpatialData()
105+
106+
cell_id_col = sdata["transcripts"][f"cell_id"]
107+
sdata.tables["table"]=ad.AnnData(obs=pd.DataFrame({"cell_id":cell_id_col}), var=sdata.tables["table"].var[[]])
108+
sdata_new = sd.SpatialData(
109+
points=sdata.points,
110+
tables=sdata.tables
111+
)
112+
113+
output_path = par['output']
114+
sdata_new.write(output_path, overwrite=True)
115+
116+

0 commit comments

Comments
 (0)