Skip to content

Commit b717064

Browse files
Merge pull request #171 from MannLabs/deprecate_cell_id_tracking
improve spatialdata file handling
2 parents aac3456 + 5e765ac commit b717064

File tree

7 files changed

+152
-246
lines changed

7 files changed

+152
-246
lines changed

src/scportrait/pipeline/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ class ProcessingStep(Logable):
153153
DEFAULT_SEG_NAME_0 = "nucleus"
154154
DEFAULT_SEG_NAME_1 = "cytosol"
155155

156-
DEFAULT_CENTERS_NAME = "centers_cells"
156+
DEFAULT_CENTERS_NAME = "centers"
157157

158158
DEFAULT_CHUNK_SIZE = (1, 1000, 1000)
159159

src/scportrait/pipeline/_utils/sdata_io.py

Lines changed: 89 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
import numpy as np
99
import xarray
1010
from alphabase.io import tempmmap
11+
from anndata import AnnData
1112
from spatialdata import SpatialData
12-
from spatialdata.models import Image2DModel, PointsModel, TableModel
13+
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel
1314
from spatialdata.transformations.transformations import Identity
1415

1516
from scportrait.pipeline._base import Logable
16-
from scportrait.pipeline._utils.spatialdata_classes import spLabels2DModel
1717
from scportrait.pipeline._utils.spatialdata_helper import (
1818
calculate_centroids,
1919
get_chunk_size,
@@ -87,13 +87,6 @@ def _read_sdata(self) -> SpatialData:
8787
_sdata = self._create_empty_sdata()
8888
_sdata.write(self.sdata_path, overwrite=True)
8989

90-
allowed_labels = ["seg_all_nucleus", "seg_all_cytosol"]
91-
for key in _sdata.labels:
92-
if key in allowed_labels:
93-
segmentation_object = _sdata.labels[key]
94-
if not hasattr(segmentation_object.attrs, "cell_ids"):
95-
segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)
96-
9790
return _sdata
9891

9992
def get_sdata(self) -> SpatialData:
@@ -194,49 +187,47 @@ def _write_image_sdata(
194187
# check if the image is already a multi-scale image
195188
if isinstance(image, xarray.DataTree):
196189
# if so only validate the model since this means we are getting the image from a spatialdata object already
197-
# image = Image2DModel.validate(image)
198-
# this appraoch is currently not functional but an issue was opened at https://github.com/scverse/spatialdata/issues/865
190+
# fix until #https://github.com/scverse/spatialdata/issues/528 is resolved
191+
Image2DModel().validate(image)
199192
if scale_factors is not None:
200193
Warning("Scale factors are ignored when passing a multi-scale image.")
201-
image = image.scale0.image
202-
203-
if scale_factors is None:
204-
scale_factors = [2, 4, 8]
205-
if scale_factors is None:
206-
scale_factors = [2, 4, 8]
207-
208-
if isinstance(image, xarray.DataArray):
209-
# if so first validate the model since this means we are getting the image from a spatialdata object already
210-
# then apply the scales transform
211-
# image = Image2DModel.validate(image)
212-
# this appraoch is currently not functional but an issue was opened at https://github.com/scverse/spatialdata/issues/865
213-
214-
if channel_names is not None:
215-
Warning(
216-
"Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray."
194+
else:
195+
if scale_factors is None:
196+
scale_factors = [2, 4, 8]
197+
if scale_factors is None:
198+
scale_factors = [2, 4, 8]
199+
200+
if isinstance(image, xarray.DataArray):
201+
# if so first validate the model since this means we are getting the image from a spatialdata object already
202+
# fix until #https://github.com/scverse/spatialdata/issues/528 is resolved
203+
Image2DModel().validate(image)
204+
205+
if channel_names is not None:
206+
Warning(
207+
"Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray."
208+
)
209+
210+
image = Image2DModel.parse(
211+
image,
212+
scale_factors=scale_factors,
213+
rgb=False,
217214
)
218215

219-
image = Image2DModel.parse(
220-
image,
221-
scale_factors=scale_factors,
222-
rgb=False,
223-
)
224-
225-
else:
226-
if channel_names is None:
227-
channel_names = [f"channel_{i}" for i in range(image.shape[0])]
228-
229-
# transform to spatialdata image model
230-
transform_original = Identity()
231-
image = Image2DModel.parse(
232-
image,
233-
dims=["c", "y", "x"],
234-
chunks=chunks,
235-
c_coords=channel_names,
236-
scale_factors=scale_factors,
237-
transformations={"global": transform_original},
238-
rgb=False,
239-
)
216+
else:
217+
if channel_names is None:
218+
channel_names = [f"channel_{i}" for i in range(image.shape[0])]
219+
220+
# transform to spatialdata image model
221+
transform_original = Identity()
222+
image = Image2DModel.parse(
223+
image,
224+
dims=["c", "y", "x"],
225+
chunks=chunks,
226+
c_coords=channel_names,
227+
scale_factors=scale_factors,
228+
transformations={"global": transform_original},
229+
rgb=False,
230+
)
240231

241232
if overwrite:
242233
self._force_delete_object(_sdata, image_name, "images")
@@ -249,7 +240,7 @@ def _write_image_sdata(
249240

250241
def _write_segmentation_object_sdata(
251242
self,
252-
segmentation_object: spLabels2DModel,
243+
segmentation_object: Labels2DModel,
253244
segmentation_label: str,
254245
classes: set[str] | None = None,
255246
overwrite: bool = False,
@@ -264,9 +255,6 @@ def _write_segmentation_object_sdata(
264255
"""
265256
_sdata = self._read_sdata()
266257

267-
if not hasattr(segmentation_object.attrs, "cell_ids"):
268-
segmentation_object = spLabels2DModel().convert(segmentation_object, classes=classes)
269-
270258
if overwrite:
271259
self._force_delete_object(_sdata, segmentation_label, "labels")
272260

@@ -294,7 +282,7 @@ def _write_segmentation_sdata(
294282
overwrite: Whether to overwrite existing data
295283
"""
296284
transform_original = Identity()
297-
mask = spLabels2DModel.parse(
285+
mask = Labels2DModel.parse(
298286
segmentation,
299287
dims=["y", "x"],
300288
transformations={"global": transform_original},
@@ -324,6 +312,48 @@ def _write_points_object_sdata(self, points: PointsModel, points_name: str, over
324312

325313
self.log(f"Points {points_name} written to sdata object.")
326314

315+
def _write_table_sdata(
316+
self, adata: AnnData, table_name: str, segmentation_mask_name: str, overwrite: bool = False
317+
) -> None:
318+
"""Write anndata to SpatialData.
319+
320+
Args:
321+
adata: AnnData object to write
322+
table_name: Name for the table object under which it should be saved
323+
segmentation_mask_name: Name of the segmentation mask that this table annotates
324+
overwrite: Whether to overwrite existing data
325+
326+
Returns:
327+
None (writes to sdata object)
328+
"""
329+
_sdata = self._read_sdata()
330+
331+
assert isinstance(adata, AnnData), "Input data must be an AnnData object."
332+
assert segmentation_mask_name in _sdata.labels, "Segmentation mask not found in sdata object."
333+
334+
# get obs and obs_indices
335+
obs = adata.obs
336+
obs_indices = adata.obs.index.astype(int) # need to ensure int subtype to be able to annotate seg masks
337+
338+
# sanity checking
339+
assert len(obs_indices) == len(set(obs_indices)), "Instance IDs are not unique."
340+
cell_ids_mask = set(_sdata[f"{self.centers_name}_{segmentation_mask_name}"].index.values.compute())
341+
assert set(obs_indices).issubset(cell_ids_mask), "Instance IDs do not match segmentation mask cell IDs."
342+
343+
obs["instance_id"] = obs_indices
344+
obs["region"] = segmentation_mask_name
345+
obs["region"] = obs["region"].astype("category")
346+
347+
adata.obs = obs
348+
table = TableModel.parse(
349+
adata,
350+
region=[segmentation_mask_name],
351+
region_key="region",
352+
instance_key="instance_id",
353+
)
354+
355+
self._write_table_object_sdata(table, table_name, overwrite=overwrite)
356+
327357
def _write_table_object_sdata(self, table: TableModel, table_name: str, overwrite: bool = False) -> None:
328358
"""Write table object to SpatialData.
329359
@@ -358,7 +388,10 @@ def _get_centers(self, sdata: SpatialData, segmentation_label: str) -> PointsMod
358388
if segmentation_label not in sdata.labels:
359389
raise ValueError(f"Segmentation {segmentation_label} not found in sdata object.")
360390

361-
centers = calculate_centroids(sdata.labels[segmentation_label])
391+
mask = sdata.labels[segmentation_label]
392+
if isinstance(mask, xarray.DataTree):
393+
mask = mask.scale0.image
394+
centers = calculate_centroids(mask)
362395
return centers
363396

364397
def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None:
@@ -370,7 +403,8 @@ def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None
370403
"""
371404
_sdata = self._read_sdata()
372405
centroids_object = self._get_centers(_sdata, segmentation_label)
373-
self._write_points_object_sdata(centroids_object, self.centers_name, overwrite=overwrite)
406+
centers_name = f"{self.centers_name}_{segmentation_label}"
407+
self._write_points_object_sdata(centroids_object, centers_name, overwrite=overwrite)
374408

375409
## load elements from sdata to a memory mapped array
376410
def _load_input_image_to_memmap(

src/scportrait/pipeline/_utils/spatialdata_classes.py

Lines changed: 0 additions & 108 deletions
This file was deleted.

src/scportrait/pipeline/extraction.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,7 @@ def _get_segmentation_info(self):
324324
# this mask will be used to calculate the cell centers
325325
if self.n_masks == 2:
326326
# perform sanity check that the masks have the same ids
327-
assert (
328-
_sdata[self.nucleus_key].attrs["cell_ids"] == _sdata[self.cytosol_key].attrs["cell_ids"]
329-
), "Nucleus and cytosol masks contain different cell ids. Cannot proceed with extraction."
327+
# THIS NEEDS TO BE IMPLEMENTED HERE
330328

331329
self.main_segmenation_mask = self.nucleus_key
332330

@@ -352,27 +350,26 @@ def _get_centers(self):
352350
_sdata = self.filehandler._read_sdata()
353351

354352
# calculate centers if they have not been calculated yet
355-
if self.DEFAULT_CENTERS_NAME not in _sdata:
353+
centers_name = f"{self.DEFAULT_CENTERS_NAME}_{self.main_segmenation_mask}"
354+
if centers_name not in _sdata:
356355
self.filehandler._add_centers(self.main_segmenation_mask, overwrite=self.overwrite)
357356
_sdata = self.filehandler._read_sdata() # reread to ensure we have updated version
358357

359-
centers = _sdata[self.DEFAULT_CENTERS_NAME].values.compute()
358+
centers = _sdata[centers_name].values.compute()
360359

361360
# round to int so that we can use them as indices
362361
centers = np.round(centers).astype(int)
363362

364363
self.centers = centers
365-
self.centers_cell_ids = _sdata[self.DEFAULT_CENTERS_NAME].index.values.compute()
364+
self.centers_cell_ids = _sdata[centers_name].index.values.compute()
366365

367366
# ensure that the centers ids are unique
368367
assert len(self.centers_cell_ids) == len(
369368
set(self.centers_cell_ids)
370369
), "Cell ids in centers are not unique. Cannot proceed with extraction."
371370

372371
# double check that the cell_ids contained in the seg masks match to those from centers
373-
assert set(self.centers_cell_ids) == set(
374-
_sdata[self.main_segmenation_mask].attrs["cell_ids"]
375-
), "Cell ids from centers do not match those from the segmentation mask. Cannot proceed with extraction."
372+
# THIS NEEDS TO BE IMPLEMENTED HERE
376373

377374
def _get_classes_to_extract(self):
378375
if self.partial_processing:

0 commit comments

Comments
 (0)