Skip to content

Commit 206c5d2

Browse files
committed
deprecate _force_delete_object in favour of new spdata function
1 parent ddf0cde commit 206c5d2

File tree

1 file changed

+11
-58
lines changed

1 file changed

+11
-58
lines changed

src/scportrait/pipeline/_utils/sdata_io.py

Lines changed: 11 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
calculate_centroids,
1919
get_chunk_size,
2020
)
21+
from scportrait.spdata.write._helper import add_element_sdata
2122

2223
ChunkSize2D: TypeAlias = tuple[int, int]
2324
ChunkSize3D: TypeAlias = tuple[int, int, int]
@@ -97,21 +98,6 @@ def get_sdata(self) -> SpatialData:
9798
"""
9899
return self._read_sdata()
99100

100-
def _force_delete_object(self, sdata: SpatialData, name: str, type: ObjectType) -> None:
101-
"""Force delete an object from the SpatialData object and directory.
102-
103-
Args:
104-
sdata: SpatialData object
105-
name: Name of object to delete
106-
type: Type of object ("images", "labels", "points", "tables")
107-
"""
108-
if name in sdata:
109-
del sdata[name]
110-
111-
path = os.path.join(self.sdata_path, type, name)
112-
if os.path.exists(path):
113-
shutil.rmtree(path, ignore_errors=True)
114-
115101
def _check_sdata_status(self, return_sdata: bool = False) -> SpatialData | None:
116102
"""Check status of SpatialData objects.
117103
@@ -229,12 +215,7 @@ def _write_image_sdata(
229215
rgb=False,
230216
)
231217

232-
if overwrite:
233-
self._force_delete_object(_sdata, image_name, "images")
234-
235-
_sdata.images[image_name] = image
236-
_sdata.write_element(image_name, overwrite=True)
237-
218+
add_element_sdata(_sdata, image, image_name, overwrite=overwrite)
238219
self.log(f"Image {image_name} written to sdata object.")
239220
self._check_sdata_status()
240221

@@ -253,13 +234,7 @@ def _write_segmentation_object_sdata(
253234
overwrite: Whether to overwrite existing data
254235
"""
255236
_sdata = self._read_sdata()
256-
257-
if overwrite:
258-
self._force_delete_object(_sdata, segmentation_label, "labels")
259-
260-
_sdata.labels[segmentation_label] = segmentation_object
261-
_sdata.write_element(segmentation_label, overwrite=True)
262-
237+
add_element_sdata(_sdata, segmentation_object, segmentation_label, overwrite=overwrite)
263238
self.log(f"Segmentation {segmentation_label} written to sdata object.")
264239
self._check_sdata_status()
265240

@@ -281,10 +256,7 @@ def _write_segmentation_sdata(
281256
"""
282257
transform_original = Identity()
283258
mask = Labels2DModel.parse(
284-
segmentation,
285-
dims=["y", "x"],
286-
transformations={"global": transform_original},
287-
chunks=chunks,
259+
segmentation, dims=["y", "x"], transformations={"global": transform_original}, chunks=chunks
288260
)
289261

290262
if not get_chunk_size(mask) == chunks:
@@ -301,14 +273,9 @@ def _write_points_object_sdata(self, points: PointsModel, points_name: str, over
301273
overwrite: Whether to overwrite existing data
302274
"""
303275
_sdata = self._read_sdata()
304-
305-
if overwrite:
306-
self._force_delete_object(_sdata, points_name, "points")
307-
308-
_sdata.points[points_name] = points
309-
_sdata.write_element(points_name, overwrite=True)
310-
276+
add_element_sdata(_sdata, points, points_name, overwrite=overwrite)
311277
self.log(f"Points {points_name} written to sdata object.")
278+
self._check_sdata_status()
312279

313280
def _write_table_sdata(
314281
self, adata: AnnData, table_name: str, segmentation_mask_name: str, overwrite: bool = False
@@ -344,10 +311,7 @@ def _write_table_sdata(
344311

345312
adata.obs = obs
346313
table = TableModel.parse(
347-
adata,
348-
region=[segmentation_mask_name],
349-
region_key="region",
350-
instance_key="instance_id",
314+
adata, region=[segmentation_mask_name], region_key="region", instance_key="instance_id"
351315
)
352316

353317
self._write_table_object_sdata(table, table_name, overwrite=overwrite)
@@ -361,14 +325,9 @@ def _write_table_object_sdata(self, table: TableModel, table_name: str, overwrit
361325
overwrite: Whether to overwrite existing data
362326
"""
363327
_sdata = self._read_sdata()
364-
365-
if overwrite:
366-
self._force_delete_object(_sdata, table_name, "tables")
367-
368-
_sdata.tables[table_name] = table
369-
_sdata.write_element(table_name, overwrite=True)
370-
328+
add_element_sdata(_sdata, table, table_name, overwrite=overwrite)
371329
self.log(f"Table {table_name} written to sdata object.")
330+
self._check_sdata_status()
372331

373332
def _get_centers(self, sdata: SpatialData, segmentation_label: str) -> PointsModel:
374333
"""Get cell centers from segmentation.
@@ -433,11 +392,7 @@ def _load_input_image_to_memmap(
433392
shape = image.shape
434393

435394
# initialize empty memory mapped arrays to store the data
436-
path_input_image = tempmmap.create_empty_mmap(
437-
shape=shape,
438-
dtype=image.dtype,
439-
tmp_dir_abs_path=tmp_dir_abs_path,
440-
)
395+
path_input_image = tempmmap.create_empty_mmap(shape=shape, dtype=image.dtype, tmp_dir_abs_path=tmp_dir_abs_path)
441396

442397
input_image_mmap = tempmmap.mmap_array_from_path(path_input_image)
443398

@@ -521,9 +476,7 @@ def _load_seg_to_memmap(
521476

522477
# initialize empty memory mapped arrays to store the data
523478
path_seg_masks = tempmmap.create_empty_mmap(
524-
shape=shape,
525-
dtype=seg_objects[0].data.dtype,
526-
tmp_dir_abs_path=tmp_dir_abs_path,
479+
shape=shape, dtype=seg_objects[0].data.dtype, tmp_dir_abs_path=tmp_dir_abs_path
527480
)
528481

529482
seg_masks = tempmmap.mmap_array_from_path(path_seg_masks)

0 commit comments

Comments
 (0)