Skip to content

Commit 664353b

Browse files
committed
simplify _from_sdata method to rely on methods implemented in spatialdata
1 parent 907bf19 commit 664353b

File tree

1 file changed

+32
-164
lines changed

1 file changed

+32
-164
lines changed

src/scportrait/pipeline/project.py

Lines changed: 32 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
rechunk_image,
4343
remap_region_annotation_table,
4444
)
45+
from scportrait.spdata.write._helper import _get_shape, _make_key_lookup
4546

4647
if TYPE_CHECKING:
4748
from collections.abc import Callable
@@ -875,186 +876,53 @@ def load_input_from_sdata(
875876

876877
# read input sdata object
877878
sdata_input = SpatialData.read(sdata_path)
878-
if keep_all:
879-
shutil.rmtree(self.sdata_path, ignore_errors=True) # remove old sdata object
880-
sdata_input.write(self.sdata_path, overwrite=True)
881-
del sdata_input
882-
sdata_input = self.filehandler.get_sdata()
883-
884-
self.get_project_status()
885-
886-
# get input image and write it to the final sdata object
887-
image = sdata_input.images[input_image_name]
888-
self.log(f"Adding image {input_image_name} to sdata object as 'input_image'.")
889-
890-
if isinstance(image, xarray.DataTree):
891-
image_c, image_x, image_y = image.scale0.image.shape
892-
893-
# ensure chunking is correct
894-
if rechunk:
895-
for scale in image:
896-
self._check_chunk_size(image[scale].image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D)
879+
all_elements = [x.split("/")[1] for x in sdata_input.elements_paths_in_memory()]
897880

898-
# get channel names
899-
channel_names = image.scale0.image.c.values
881+
dict_elems = {self.DEFAULT_INPUT_IMAGE_NAME: sdata_input[input_image_name]}
900882

901-
elif isinstance(image, xarray.DataArray):
902-
image_c, image_x, image_y = image.shape
903-
904-
# ensure chunking is correct
905-
if rechunk:
906-
self._check_chunk_size(image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D)
907-
908-
channel_names = image.c.values
909-
910-
# Reset all transformations
911-
if image.attrs.get("transform"):
912-
self.log("Image contained transformations which which were removed.")
913-
image.attrs["transform"] = None
914-
915-
# check coordinate system of input image
916-
### PLACEHOLDER
917-
918-
# check channel names
919-
self.log(
920-
f"Found the following channel names in the input image and saving in the spatialdata object: {channel_names}"
921-
)
922-
923-
self.filehandler._write_image_sdata(image, self.DEFAULT_INPUT_IMAGE_NAME, channel_names=channel_names)
924-
925-
# check if a nucleus segmentation exists and if so add it to the sdata object
926883
if nucleus_segmentation_name is not None:
927-
mask = sdata_input.labels[nucleus_segmentation_name]
928-
self.log(
929-
f"Adding nucleus segmentation mask '{nucleus_segmentation_name}' to sdata object as '{self.nuc_seg_name}'."
930-
)
931-
932-
# if mask is multi-scale ensure we only use the scale 0
933-
if isinstance(mask, xarray.DataTree):
934-
mask = mask["scale0"].image
935-
936-
# ensure that loaded masks are at the same scale as the input image
937-
mask_x, mask_y = mask.shape
938-
assert (mask_x == image_x) and (
939-
mask_y == image_y
940-
), "Nucleus segmentation mask does not match input image size."
941-
942-
if rechunk:
943-
self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct
944-
945-
self.filehandler._write_segmentation_object_sdata(mask, self.nuc_seg_name)
946-
self.log(
947-
f"Calculating centers for nucleus segmentation mask {self.nuc_seg_name} and adding to spatialdata object."
948-
)
949-
self.filehandler._add_centers(segmentation_label=self.nuc_seg_name)
884+
dict_elems[self.nuc_seg_name] = sdata_input[nucleus_segmentation_name]
885+
if remove_duplicates:
886+
all_elements.remove(nucleus_segmentation_name)
950887

951-
# check if a cytosol segmentation exists and if so add it to the sdata object
952888
if cytosol_segmentation_name is not None:
953-
mask = sdata_input.labels[cytosol_segmentation_name]
954-
self.log(
955-
f"Adding cytosol segmentation mask '{cytosol_segmentation_name}' to sdata object as '{self.cyto_seg_name}'."
956-
)
957-
958-
# if mask is multi-scale ensure we only use the scale 0
959-
if isinstance(mask, xarray.DataTree):
960-
mask = mask["scale0"].image
889+
dict_elems[self.cyto_seg_name] = sdata_input[cytosol_segmentation_name]
890+
if remove_duplicates:
891+
all_elements.remove(cytosol_segmentation_name)
961892

962-
# ensure that loaded masks are at the same scale as the input image
963-
mask_x, mask_y = mask.shape
964-
assert (mask_x == image_x) and (
965-
mask_y == image_y
966-
), "Nucleus segmentation mask does not match input image size."
893+
if keep_all:
894+
shutil.rmtree(self.sdata_path, ignore_errors=True)
895+
for elem in all_elements:
896+
dict_elems[elem] = sdata_input[elem]
967897

968-
if rechunk:
969-
self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct
970-
971-
self.filehandler._write_segmentation_object_sdata(mask, self.cyto_seg_name)
972-
self.log(
973-
f"Calculating centers for cytosol segmentation mask {self.nuc_seg_name} and adding to spatialdata object."
974-
)
975-
self.filehandler._add_centers(segmentation_label=self.cyto_seg_name)
898+
sdata = SpatialData.from_elements_dict(dict_elems)
899+
sdata.write(self.sdata_path, overwrite=True)
976900

901+
# update project status
977902
self.get_project_status()
903+
_, x, y = _get_shape(sdata[self.DEFAULT_INPUT_IMAGE_NAME])
978904

979-
# ensure that the provided nucleus and cytosol segmentations fulfill the scPortrait requirements
980-
# requirements are:
981-
# 1. The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids
982-
# if self.nuc_seg_status and self.cyto_seg_status:
983-
# THIS NEEDS TO BE IMPLEMENTED HERE
984-
985-
# 2. the nucleus segmentation ids and the cytosol segmentation ids need to match
986-
# THIS NEEDS TO BE IMPLEMENTED HERE
987-
988-
# check if there are any annotations that match the nucleus/cytosol segmentations
989-
if self.nuc_seg_status or self.cyto_seg_status:
990-
region_annotation = generate_region_annotation_lookuptable(self.sdata)
991-
992-
if self.nuc_seg_status:
993-
region_name = self.nuc_seg_name
994-
995-
# add existing nucleus annotations if available
996-
if nucleus_segmentation_name in region_annotation.keys():
997-
for x in region_annotation[nucleus_segmentation_name]:
998-
table_name, table = x
999-
1000-
new_table_name = f"annot_{region_name}_{table_name}"
905+
self.overwrite = original_overwrite
1001906

1002-
table = remap_region_annotation_table(table, region_name=region_name)
907+
if self.nuc_seg_status:
908+
# check input size
909+
_, x_mask, y_mask = _get_shape(sdata[self.nuc_seg_name])
910+
assert x == x_mask and y == y_mask, "Input image and nucleus segmentation mask do not match in size."
1003911

1004-
self.filehandler._write_table_object_sdata(table, new_table_name)
1005-
self.log(
1006-
f"Added annotation {new_table_name} to spatialdata object for segmentation object {region_name}."
1007-
)
1008-
1009-
if keep_all and remove_duplicates:
1010-
self.log(
1011-
f"Deleting original annotation {table_name} for nucleus segmentation {nucleus_segmentation_name} from sdata object to prevent information duplication."
1012-
)
1013-
self.filehandler._force_delete_object(self.sdata, name=table_name, type="tables")
1014-
else:
1015-
self.log(f"No region annotation found for the nucleus segmentation {nucleus_segmentation_name}.")
1016-
1017-
# add cytosol segmentations if available
1018-
if self.cyto_seg_status:
1019-
if cytosol_segmentation_name in region_annotation.keys():
1020-
for x in region_annotation[cytosol_segmentation_name]:
1021-
table_name, table = x
1022-
region_name = self.cyto_seg_name
1023-
new_table_name = f"annot_{region_name}_{table_name}"
1024-
1025-
table = remap_region_annotation_table(table, region_name=region_name)
1026-
self.filehandler._write_table_object_sdata(table, new_table_name)
1027-
1028-
self.log(
1029-
f"Added annotation {new_table_name} to spatialdata object for segmentation object {region_name}."
1030-
)
1031-
1032-
if keep_all and remove_duplicates:
1033-
self.log(
1034-
f"Deleting original annotation {table_name} for cytosol segmentation {cytosol_segmentation_name} from sdata object to prevent information duplication."
1035-
)
1036-
self.filehandler._force_delete_object(self.sdata, name=table_name, type="tables")
1037-
else:
1038-
self.log(f"No region annotation found for the cytosol segmentation {cytosol_segmentation_name}.")
912+
self.filehandler._add_centers(segmentation_label=self.nuc_seg_name)
913+
if self.cyto_seg_status:
914+
# check input size
915+
_, x_mask, y_mask = _get_shape(sdata[self.cyto_seg_name])
916+
assert x == x_mask and y == y_mask, "Input image and nucleus segmentation mask do not match in size."
1039917

1040-
if keep_all and remove_duplicates:
1041-
# remove input image
1042-
self.log(f"Deleting input image '{input_image_name}' from sdata object to prevent information duplication.")
1043-
self.filehandler._force_delete_object(self.sdata, name=input_image_name, type="images")
918+
self.filehandler._add_centers(segmentation_label=self.cyto_seg_name)
1044919

1045-
if self.nuc_seg_status:
1046-
self.log(
1047-
f"Deleting original nucleus segmentation mask '{nucleus_segmentation_name}' from sdata object to prevent information duplication."
1048-
)
1049-
self.filehandler._force_delete_object(self.sdata, name=nucleus_segmentation_name, type="labels")
1050-
if self.cyto_seg_status:
1051-
self.log(
1052-
f"Deleting original cytosol segmentation mask '{cytosol_segmentation_name}' from sdata object to prevent information duplication."
1053-
)
1054-
self.filehandler._force_delete_object(self.sdata, name=cytosol_segmentation_name, type="labels")
920+
if self.nuc_seg_status and self.cyto_seg_status:
921+
ids_nuc = set(sdata[f"{self.DEFAULT_CENTERS_NAME}_{self.nuc_seg_name}"].index.values)
922+
ids_cyto = set(sdata[f"{self.DEFAULT_CENTERS_NAME}_{self.cyto_seg_name}"].index.values)
923+
assert ids_nuc == ids_cyto, "Nucleus and cytosol segmentation masks do not match."
1055924

1056925
self.get_project_status()
1057-
self.overwrite = original_overwrite # reset to original value
1058926

1059927
#### Functions to perform processing ####
1060928

0 commit comments

Comments
 (0)