88import numpy as np
99import xarray
1010from alphabase .io import tempmmap
11+ from anndata import AnnData
1112from spatialdata import SpatialData
12- from spatialdata .models import Image2DModel , PointsModel , TableModel
13+ from spatialdata .models import Image2DModel , Labels2DModel , PointsModel , TableModel
1314from spatialdata .transformations .transformations import Identity
1415
1516from scportrait .pipeline ._base import Logable
16- from scportrait .pipeline ._utils .spatialdata_classes import spLabels2DModel
1717from scportrait .pipeline ._utils .spatialdata_helper import (
1818 calculate_centroids ,
1919 get_chunk_size ,
@@ -87,13 +87,6 @@ def _read_sdata(self) -> SpatialData:
8787 _sdata = self ._create_empty_sdata ()
8888 _sdata .write (self .sdata_path , overwrite = True )
8989
90- allowed_labels = ["seg_all_nucleus" , "seg_all_cytosol" ]
91- for key in _sdata .labels :
92- if key in allowed_labels :
93- segmentation_object = _sdata .labels [key ]
94- if not hasattr (segmentation_object .attrs , "cell_ids" ):
95- segmentation_object = spLabels2DModel ().convert (segmentation_object , classes = None )
96-
9790 return _sdata
9891
9992 def get_sdata (self ) -> SpatialData :
@@ -194,49 +187,47 @@ def _write_image_sdata(
194187 # check if the image is already a multi-scale image
195188 if isinstance (image , xarray .DataTree ):
196189 # if so only validate the model since this means we are getting the image from a spatialdata object already
197- # image = Image2DModel.validate(image)
198- # this appraoch is currently not functional but an issue was opened at https://github.com/scverse/spatialdata/issues/865
190+ # fix until #https://github.com/scverse/spatialdata/issues/528 is resolved
191+ Image2DModel (). validate ( image )
199192 if scale_factors is not None :
200193 Warning ("Scale factors are ignored when passing a multi-scale image." )
201- image = image .scale0 .image
202-
203- if scale_factors is None :
204- scale_factors = [2 , 4 , 8 ]
205- if scale_factors is None :
206- scale_factors = [2 , 4 , 8 ]
207-
208- if isinstance (image , xarray .DataArray ):
209- # if so first validate the model since this means we are getting the image from a spatialdata object already
210- # then apply the scales transform
211- # image = Image2DModel.validate(image)
212- # this appraoch is currently not functional but an issue was opened at https://github.com/scverse/spatialdata/issues/865
213-
214- if channel_names is not None :
215- Warning (
216- "Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray."
194+ else :
195+ if scale_factors is None :
196+ scale_factors = [2 , 4 , 8 ]
197+ if scale_factors is None :
198+ scale_factors = [2 , 4 , 8 ]
199+
200+ if isinstance (image , xarray .DataArray ):
201+ # if so first validate the model since this means we are getting the image from a spatialdata object already
202+ # fix until #https://github.com/scverse/spatialdata/issues/528 is resolved
203+ Image2DModel ().validate (image )
204+
205+ if channel_names is not None :
206+ Warning (
207+ "Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray."
208+ )
209+
210+ image = Image2DModel .parse (
211+ image ,
212+ scale_factors = scale_factors ,
213+ rgb = False ,
217214 )
218215
219- image = Image2DModel .parse (
220- image ,
221- scale_factors = scale_factors ,
222- rgb = False ,
223- )
224-
225- else :
226- if channel_names is None :
227- channel_names = [f"channel_{ i } " for i in range (image .shape [0 ])]
228-
229- # transform to spatialdata image model
230- transform_original = Identity ()
231- image = Image2DModel .parse (
232- image ,
233- dims = ["c" , "y" , "x" ],
234- chunks = chunks ,
235- c_coords = channel_names ,
236- scale_factors = scale_factors ,
237- transformations = {"global" : transform_original },
238- rgb = False ,
239- )
216+ else :
217+ if channel_names is None :
218+ channel_names = [f"channel_{ i } " for i in range (image .shape [0 ])]
219+
220+ # transform to spatialdata image model
221+ transform_original = Identity ()
222+ image = Image2DModel .parse (
223+ image ,
224+ dims = ["c" , "y" , "x" ],
225+ chunks = chunks ,
226+ c_coords = channel_names ,
227+ scale_factors = scale_factors ,
228+ transformations = {"global" : transform_original },
229+ rgb = False ,
230+ )
240231
241232 if overwrite :
242233 self ._force_delete_object (_sdata , image_name , "images" )
@@ -249,7 +240,7 @@ def _write_image_sdata(
249240
250241 def _write_segmentation_object_sdata (
251242 self ,
252- segmentation_object : spLabels2DModel ,
243+ segmentation_object : Labels2DModel ,
253244 segmentation_label : str ,
254245 classes : set [str ] | None = None ,
255246 overwrite : bool = False ,
@@ -264,9 +255,6 @@ def _write_segmentation_object_sdata(
264255 """
265256 _sdata = self ._read_sdata ()
266257
267- if not hasattr (segmentation_object .attrs , "cell_ids" ):
268- segmentation_object = spLabels2DModel ().convert (segmentation_object , classes = classes )
269-
270258 if overwrite :
271259 self ._force_delete_object (_sdata , segmentation_label , "labels" )
272260
@@ -294,7 +282,7 @@ def _write_segmentation_sdata(
294282 overwrite: Whether to overwrite existing data
295283 """
296284 transform_original = Identity ()
297- mask = spLabels2DModel .parse (
285+ mask = Labels2DModel .parse (
298286 segmentation ,
299287 dims = ["y" , "x" ],
300288 transformations = {"global" : transform_original },
@@ -324,6 +312,48 @@ def _write_points_object_sdata(self, points: PointsModel, points_name: str, over
324312
325313 self .log (f"Points { points_name } written to sdata object." )
326314
315+ def _write_table_sdata (
316+ self , adata : AnnData , table_name : str , segmentation_mask_name : str , overwrite : bool = False
317+ ) -> None :
318+ """Write anndata to SpatialData.
319+
320+ Args:
321+ adata: AnnData object to write
322+ table_name: Name for the table object under which it should be saved
323+ segmentation_mask_name: Name of the segmentation mask that this table annotates
324+ overwrite: Whether to overwrite existing data
325+
326+ Returns:
327+ None (writes to sdata object)
328+ """
329+ _sdata = self ._read_sdata ()
330+
331+ assert isinstance (adata , AnnData ), "Input data must be an AnnData object."
332+ assert segmentation_mask_name in _sdata .labels , "Segmentation mask not found in sdata object."
333+
334+ # get obs and obs_indices
335+ obs = adata .obs
336+ obs_indices = adata .obs .index .astype (int ) # need to ensure int subtype to be able to annotate seg masks
337+
338+ # sanity checking
339+ assert len (obs_indices ) == len (set (obs_indices )), "Instance IDs are not unique."
340+ cell_ids_mask = set (_sdata [f"{ self .centers_name } _{ segmentation_mask_name } " ].index .values .compute ())
341+ assert set (obs_indices ).issubset (cell_ids_mask ), "Instance IDs do not match segmentation mask cell IDs."
342+
343+ obs ["instance_id" ] = obs_indices
344+ obs ["region" ] = segmentation_mask_name
345+ obs ["region" ] = obs ["region" ].astype ("category" )
346+
347+ adata .obs = obs
348+ table = TableModel .parse (
349+ adata ,
350+ region = [segmentation_mask_name ],
351+ region_key = "region" ,
352+ instance_key = "instance_id" ,
353+ )
354+
355+ self ._write_table_object_sdata (table , table_name , overwrite = overwrite )
356+
327357 def _write_table_object_sdata (self , table : TableModel , table_name : str , overwrite : bool = False ) -> None :
328358 """Write table object to SpatialData.
329359
@@ -358,7 +388,10 @@ def _get_centers(self, sdata: SpatialData, segmentation_label: str) -> PointsMod
358388 if segmentation_label not in sdata .labels :
359389 raise ValueError (f"Segmentation { segmentation_label } not found in sdata object." )
360390
361- centers = calculate_centroids (sdata .labels [segmentation_label ])
391+ mask = sdata .labels [segmentation_label ]
392+ if isinstance (mask , xarray .DataTree ):
393+ mask = mask .scale0 .image
394+ centers = calculate_centroids (mask )
362395 return centers
363396
364397 def _add_centers (self , segmentation_label : str , overwrite : bool = False ) -> None :
@@ -370,7 +403,8 @@ def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None
370403 """
371404 _sdata = self ._read_sdata ()
372405 centroids_object = self ._get_centers (_sdata , segmentation_label )
373- self ._write_points_object_sdata (centroids_object , self .centers_name , overwrite = overwrite )
406+ centers_name = f"{ self .centers_name } _{ segmentation_label } "
407+ self ._write_points_object_sdata (centroids_object , centers_name , overwrite = overwrite )
374408
375409 ## load elements from sdata to a memory mapped array
376410 def _load_input_image_to_memmap (
0 commit comments