1818 calculate_centroids ,
1919 get_chunk_size ,
2020)
21+ from scportrait .spdata .write ._helper import add_element_sdata
2122
2223ChunkSize2D : TypeAlias = tuple [int , int ]
2324ChunkSize3D : TypeAlias = tuple [int , int , int ]
@@ -97,21 +98,6 @@ def get_sdata(self) -> SpatialData:
9798 """
9899 return self ._read_sdata ()
99100
100- def _force_delete_object (self , sdata : SpatialData , name : str , type : ObjectType ) -> None :
101- """Force delete an object from the SpatialData object and directory.
102-
103- Args:
104- sdata: SpatialData object
105- name: Name of object to delete
106- type: Type of object ("images", "labels", "points", "tables")
107- """
108- if name in sdata :
109- del sdata [name ]
110-
111- path = os .path .join (self .sdata_path , type , name )
112- if os .path .exists (path ):
113- shutil .rmtree (path , ignore_errors = True )
114-
115101 def _check_sdata_status (self , return_sdata : bool = False ) -> SpatialData | None :
116102 """Check status of SpatialData objects.
117103
@@ -229,12 +215,7 @@ def _write_image_sdata(
229215 rgb = False ,
230216 )
231217
232- if overwrite :
233- self ._force_delete_object (_sdata , image_name , "images" )
234-
235- _sdata .images [image_name ] = image
236- _sdata .write_element (image_name , overwrite = True )
237-
218+ add_element_sdata (_sdata , image , image_name , overwrite = overwrite )
238219 self .log (f"Image { image_name } written to sdata object." )
239220 self ._check_sdata_status ()
240221
@@ -253,13 +234,7 @@ def _write_segmentation_object_sdata(
253234 overwrite: Whether to overwrite existing data
254235 """
255236 _sdata = self ._read_sdata ()
256-
257- if overwrite :
258- self ._force_delete_object (_sdata , segmentation_label , "labels" )
259-
260- _sdata .labels [segmentation_label ] = segmentation_object
261- _sdata .write_element (segmentation_label , overwrite = True )
262-
237+ add_element_sdata (_sdata , segmentation_object , segmentation_label , overwrite = overwrite )
263238 self .log (f"Segmentation { segmentation_label } written to sdata object." )
264239 self ._check_sdata_status ()
265240
@@ -281,10 +256,7 @@ def _write_segmentation_sdata(
281256 """
282257 transform_original = Identity ()
283258 mask = Labels2DModel .parse (
284- segmentation ,
285- dims = ["y" , "x" ],
286- transformations = {"global" : transform_original },
287- chunks = chunks ,
259+ segmentation , dims = ["y" , "x" ], transformations = {"global" : transform_original }, chunks = chunks
288260 )
289261
290262 if not get_chunk_size (mask ) == chunks :
@@ -301,14 +273,9 @@ def _write_points_object_sdata(self, points: PointsModel, points_name: str, over
301273 overwrite: Whether to overwrite existing data
302274 """
303275 _sdata = self ._read_sdata ()
304-
305- if overwrite :
306- self ._force_delete_object (_sdata , points_name , "points" )
307-
308- _sdata .points [points_name ] = points
309- _sdata .write_element (points_name , overwrite = True )
310-
276+ add_element_sdata (_sdata , points , points_name , overwrite = overwrite )
311277 self .log (f"Points { points_name } written to sdata object." )
278+ self ._check_sdata_status ()
312279
313280 def _write_table_sdata (
314281 self , adata : AnnData , table_name : str , segmentation_mask_name : str , overwrite : bool = False
@@ -344,10 +311,7 @@ def _write_table_sdata(
344311
345312 adata .obs = obs
346313 table = TableModel .parse (
347- adata ,
348- region = [segmentation_mask_name ],
349- region_key = "region" ,
350- instance_key = "instance_id" ,
314+ adata , region = [segmentation_mask_name ], region_key = "region" , instance_key = "instance_id"
351315 )
352316
353317 self ._write_table_object_sdata (table , table_name , overwrite = overwrite )
@@ -361,14 +325,9 @@ def _write_table_object_sdata(self, table: TableModel, table_name: str, overwrit
361325 overwrite: Whether to overwrite existing data
362326 """
363327 _sdata = self ._read_sdata ()
364-
365- if overwrite :
366- self ._force_delete_object (_sdata , table_name , "tables" )
367-
368- _sdata .tables [table_name ] = table
369- _sdata .write_element (table_name , overwrite = True )
370-
328+ add_element_sdata (_sdata , table , table_name , overwrite = overwrite )
371329 self .log (f"Table { table_name } written to sdata object." )
330+ self ._check_sdata_status ()
372331
373332 def _get_centers (self , sdata : SpatialData , segmentation_label : str ) -> PointsModel :
374333 """Get cell centers from segmentation.
@@ -433,11 +392,7 @@ def _load_input_image_to_memmap(
433392 shape = image .shape
434393
435394 # initialize empty memory mapped arrays to store the data
436- path_input_image = tempmmap .create_empty_mmap (
437- shape = shape ,
438- dtype = image .dtype ,
439- tmp_dir_abs_path = tmp_dir_abs_path ,
440- )
395+ path_input_image = tempmmap .create_empty_mmap (shape = shape , dtype = image .dtype , tmp_dir_abs_path = tmp_dir_abs_path )
441396
442397 input_image_mmap = tempmmap .mmap_array_from_path (path_input_image )
443398
@@ -521,9 +476,7 @@ def _load_seg_to_memmap(
521476
522477 # initialize empty memory mapped arrays to store the data
523478 path_seg_masks = tempmmap .create_empty_mmap (
524- shape = shape ,
525- dtype = seg_objects [0 ].data .dtype ,
526- tmp_dir_abs_path = tmp_dir_abs_path ,
479+ shape = shape , dtype = seg_objects [0 ].data .dtype , tmp_dir_abs_path = tmp_dir_abs_path
527480 )
528481
529482 seg_masks = tempmmap .mmap_array_from_path (path_seg_masks )
0 commit comments