99import xarray
1010from alphabase .io import tempmmap
1111from spatialdata import SpatialData
12- from spatialdata .models import Image2DModel , PointsModel , TableModel
12+ from spatialdata .models import Image2DModel , Labels2DModel , PointsModel , TableModel
1313from spatialdata .transformations .transformations import Identity
1414
1515from scportrait .pipeline ._base import Logable
16- from scportrait .pipeline ._utils .spatialdata_classes import spLabels2DModel
1716from scportrait .pipeline ._utils .spatialdata_helper import (
1817 calculate_centroids ,
1918 get_chunk_size ,
@@ -87,13 +86,6 @@ def _read_sdata(self) -> SpatialData:
8786 _sdata = self ._create_empty_sdata ()
8887 _sdata .write (self .sdata_path , overwrite = True )
8988
90- allowed_labels = ["seg_all_nucleus" , "seg_all_cytosol" ]
91- for key in _sdata .labels :
92- if key in allowed_labels :
93- segmentation_object = _sdata .labels [key ]
94- if not hasattr (segmentation_object .attrs , "cell_ids" ):
95- segmentation_object = spLabels2DModel ().convert (segmentation_object , classes = None )
96-
9789 return _sdata
9890
9991 def get_sdata (self ) -> SpatialData :
@@ -249,7 +241,7 @@ def _write_image_sdata(
249241
250242 def _write_segmentation_object_sdata (
251243 self ,
252- segmentation_object : spLabels2DModel ,
244+ segmentation_object : Labels2DModel ,
253245 segmentation_label : str ,
254246 classes : set [str ] | None = None ,
255247 overwrite : bool = False ,
@@ -264,9 +256,6 @@ def _write_segmentation_object_sdata(
264256 """
265257 _sdata = self ._read_sdata ()
266258
267- if not hasattr (segmentation_object .attrs , "cell_ids" ):
268- segmentation_object = spLabels2DModel ().convert (segmentation_object , classes = classes )
269-
270259 if overwrite :
271260 self ._force_delete_object (_sdata , segmentation_label , "labels" )
272261
@@ -294,7 +283,7 @@ def _write_segmentation_sdata(
294283 overwrite: Whether to overwrite existing data
295284 """
296285 transform_original = Identity ()
297- mask = spLabels2DModel .parse (
286+ mask = Labels2DModel .parse (
298287 segmentation ,
299288 dims = ["y" , "x" ],
300289 transformations = {"global" : transform_original },
0 commit comments