Skip to content

Commit 8297ead

Browse files
authored
Comseg fixes (#58)
1 parent 6cf993b commit 8297ead

File tree

2 files changed

+65
-26
lines changed

2 files changed

+65
-26
lines changed

src/methods_transcript_assignment/comseg/config.vsh.yaml

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
__merge__: /src/api/comp_method_transcript_assignment.yaml
2+
13
name: comseg
24
label: "ComSeg Segmentation"
35
summary: "Spatial segmentation using ComSeg method"
@@ -11,21 +13,16 @@ links:
1113
references:
1214
doi: "10.1038/s41592-020-01018-x"
1315

14-
__merge__: /src/api/comp_method_segmentation.yaml
1516

1617
arguments:
1718
- name: --transcripts_key
1819
type: string
1920
default: "transcripts"
2021
description: "Key for transcripts in the points layer"
21-
- name: --shapes_key
22-
type: string
23-
default: "cell_boundaries"
24-
description: "Key for cell boundaries in the shapes layer"
25-
- name: --images_key
22+
- name: --coordinate_system
2623
type: string
27-
default: "morphology_mip"
28-
description: "Key for morphology image in the images layer"
24+
default: "global"
25+
description: "Coordinate system for the transcripts"
2926
- name: --patch_width
3027
type: integer
3128
default: 1200

src/methods_transcript_assignment/comseg/script.py

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1+
import dask
2+
import xarray as xr
13
import spatialdata as sd
24
import sopa
35
import anndata as ad
46
import pandas as pd
57
import numpy as np
6-
from scipy import sparse
78

89
## VIASH START
910
par = {
10-
"input": "resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr",
11-
"output": "transcripts.zarr",
12-
11+
"input_ist": "resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr",
12+
"input_segmentation": "resources_test/task_ist_preprocessing/mouse_brain_combined/segmentation.zarr",
1313
"transcripts_key": "transcripts",
14-
"shapes_key": "cell_boundaries",
15-
"images_key": "morphology_mip",
14+
"coordinate_system": "global",
15+
"output": "temp/comseg/transcripts.zarr",
16+
1617
"patch_width": 1200,
1718
"patch_overlap": 50,
1819
"transcript_patch_width": 200,
@@ -72,19 +73,36 @@ def _add_csr(X_partitions, geo_df, partition, gene_column, gene_names ):
7273
return adata
7374

7475

75-
# Read input SpatialData
76-
sdata = sd.read_zarr(par["input"])
76+
77+
# Read input files
78+
print('Reading input files', flush=True)
79+
sdata = sd.read_zarr(par['input_ist'])
80+
sdata_segm = sd.read_zarr(par['input_segmentation'])
81+
82+
83+
# Convert the prior segmentation to polygons
84+
if isinstance(sdata_segm["segmentation"], xr.DataTree):
85+
shapes_gdf = sopa.shapes.vectorize(sdata_segm["segmentation"]["scale0"].image)
86+
else:
87+
shapes_gdf = sopa.shapes.vectorize(sdata_segm["segmentation"])
88+
89+
sdata["segmentation_boundaries"] = sd.models.ShapesModel.parse(
90+
shapes_gdf, transformations=sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True).copy()
91+
)
92+
93+
# Make patches
7794
sopa.make_image_patches(sdata, patch_width=par["patch_width"], patch_overlap=par["patch_overlap"])
7895

7996
transcript_patch_args = {
8097
"sdata": sdata,
8198
"write_cells_centroids": True,
8299
"patch_width": par["transcript_patch_width"],
100+
"prior_shapes_key": "segmentation_boundaries",
83101
}
84-
transcript_patch_args["prior_shapes_key"] = par["shapes_key"]
85102

86103
sopa.make_transcript_patches(**transcript_patch_args)
87104

105+
# Run ComSeg
88106
config = {
89107
"dict_scale": {"x": 1, "y": 1, "z": 1},
90108
"mean_cell_diameter": par["mean_cell_diameter"],
@@ -96,21 +114,45 @@ def _add_csr(X_partitions, geo_df, partition, gene_column, gene_names ):
96114
"gene_column": par["gene_column"],
97115
}
98116

99-
100117
sopa.aggregation.transcripts._count_transcripts_aligned = fixed_count_transcripts_aligned
118+
# Optional: uncomment the next line to run sopa with the dask parallelization backend
# sopa.settings.parallelization_backend = 'dask'
101119
sopa.segmentation.comseg(sdata, config)
102120

121+
# Assign transcripts to cell ids
122+
sopa.spatial.assign_transcript_to_cell(
123+
sdata,
124+
points_key="transcripts",
125+
shapes_key="comseg_boundaries",
126+
key_added="cell_id",
127+
unassigned_value=0
128+
)
129+
103130
# Create output SpatialData
104-
sd_output = sd.SpatialData()
105131

106-
cell_id_col = sdata["transcripts"][f"cell_id"]
107-
sdata.tables["table"]=ad.AnnData(obs=pd.DataFrame({"cell_id":cell_id_col}), var=sdata.tables["table"].var[[]])
108-
sdata_new = sd.SpatialData(
109-
points=sdata.points,
110-
tables=sdata.tables
111-
)
132+
# Collect the unique assigned cell ids (dropping 0, the unassigned value) for the cells table
133+
print('Creating objects for cells table', flush=True)
134+
unique_cells = np.unique(sdata["transcripts"]["cell_id"])
135+
zero_idx = np.where(unique_cells == 0)
136+
if len(zero_idx[0]):
137+
unique_cells=np.delete(unique_cells, zero_idx[0][0])
138+
cell_id_col = pd.Series(unique_cells, name='cell_id', index=unique_cells)
139+
140+
# Create a SpatialData object containing only the transcripts and the cells table
141+
print('Subsetting to transcripts cell id data', flush=True)
142+
sdata_transcripts_only = sd.SpatialData(
143+
points={
144+
"transcripts": sdata['transcripts']
145+
},
146+
tables={
147+
"table": ad.AnnData(
148+
obs=pd.DataFrame(cell_id_col),
149+
var=sdata.tables["table"].var[[]]
150+
)
151+
}
152+
)
153+
112154

113155
output_path = par['output']
114-
sdata_new.write(output_path, overwrite=True)
156+
sdata_transcripts_only.write(output_path, overwrite=True)
115157

116158

0 commit comments

Comments
 (0)