|
42 | 42 | rechunk_image, |
43 | 43 | remap_region_annotation_table, |
44 | 44 | ) |
| 45 | +from scportrait.spdata.write._helper import _get_shape, _make_key_lookup |
45 | 46 |
|
46 | 47 | if TYPE_CHECKING: |
47 | 48 | from collections.abc import Callable |
@@ -875,186 +876,53 @@ def load_input_from_sdata( |
875 | 876 |
|
876 | 877 | # read input sdata object |
877 | 878 | sdata_input = SpatialData.read(sdata_path) |
878 | | - if keep_all: |
879 | | - shutil.rmtree(self.sdata_path, ignore_errors=True) # remove old sdata object |
880 | | - sdata_input.write(self.sdata_path, overwrite=True) |
881 | | - del sdata_input |
882 | | - sdata_input = self.filehandler.get_sdata() |
883 | | - |
884 | | - self.get_project_status() |
885 | | - |
886 | | - # get input image and write it to the final sdata object |
887 | | - image = sdata_input.images[input_image_name] |
888 | | - self.log(f"Adding image {input_image_name} to sdata object as 'input_image'.") |
889 | | - |
890 | | - if isinstance(image, xarray.DataTree): |
891 | | - image_c, image_x, image_y = image.scale0.image.shape |
892 | | - |
893 | | - # ensure chunking is correct |
894 | | - if rechunk: |
895 | | - for scale in image: |
896 | | - self._check_chunk_size(image[scale].image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D) |
| 879 | + all_elements = [x.split("/")[1] for x in sdata_input.elements_paths_in_memory()] |
897 | 880 |
|
898 | | - # get channel names |
899 | | - channel_names = image.scale0.image.c.values |
| 881 | + dict_elems = {self.DEFAULT_INPUT_IMAGE_NAME: sdata_input[input_image_name]} |
900 | 882 |
|
901 | | - elif isinstance(image, xarray.DataArray): |
902 | | - image_c, image_x, image_y = image.shape |
903 | | - |
904 | | - # ensure chunking is correct |
905 | | - if rechunk: |
906 | | - self._check_chunk_size(image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D) |
907 | | - |
908 | | - channel_names = image.c.values |
909 | | - |
910 | | - # Reset all transformations |
911 | | - if image.attrs.get("transform"): |
912 | | - self.log("Image contained transformations which which were removed.") |
913 | | - image.attrs["transform"] = None |
914 | | - |
915 | | - # check coordinate system of input image |
916 | | - ### PLACEHOLDER |
917 | | - |
918 | | - # check channel names |
919 | | - self.log( |
920 | | - f"Found the following channel names in the input image and saving in the spatialdata object: {channel_names}" |
921 | | - ) |
922 | | - |
923 | | - self.filehandler._write_image_sdata(image, self.DEFAULT_INPUT_IMAGE_NAME, channel_names=channel_names) |
924 | | - |
925 | | - # check if a nucleus segmentation exists and if so add it to the sdata object |
926 | 883 | if nucleus_segmentation_name is not None: |
927 | | - mask = sdata_input.labels[nucleus_segmentation_name] |
928 | | - self.log( |
929 | | - f"Adding nucleus segmentation mask '{nucleus_segmentation_name}' to sdata object as '{self.nuc_seg_name}'." |
930 | | - ) |
931 | | - |
932 | | - # if mask is multi-scale ensure we only use the scale 0 |
933 | | - if isinstance(mask, xarray.DataTree): |
934 | | - mask = mask["scale0"].image |
935 | | - |
936 | | - # ensure that loaded masks are at the same scale as the input image |
937 | | - mask_x, mask_y = mask.shape |
938 | | - assert (mask_x == image_x) and ( |
939 | | - mask_y == image_y |
940 | | - ), "Nucleus segmentation mask does not match input image size." |
941 | | - |
942 | | - if rechunk: |
943 | | - self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct |
944 | | - |
945 | | - self.filehandler._write_segmentation_object_sdata(mask, self.nuc_seg_name) |
946 | | - self.log( |
947 | | - f"Calculating centers for nucleus segmentation mask {self.nuc_seg_name} and adding to spatialdata object." |
948 | | - ) |
949 | | - self.filehandler._add_centers(segmentation_label=self.nuc_seg_name) |
| 884 | + dict_elems[self.nuc_seg_name] = sdata_input[nucleus_segmentation_name] |
| 885 | + if remove_duplicates: |
| 886 | + all_elements.remove(nucleus_segmentation_name) |
950 | 887 |
|
951 | | - # check if a cytosol segmentation exists and if so add it to the sdata object |
952 | 888 | if cytosol_segmentation_name is not None: |
953 | | - mask = sdata_input.labels[cytosol_segmentation_name] |
954 | | - self.log( |
955 | | - f"Adding cytosol segmentation mask '{cytosol_segmentation_name}' to sdata object as '{self.cyto_seg_name}'." |
956 | | - ) |
957 | | - |
958 | | - # if mask is multi-scale ensure we only use the scale 0 |
959 | | - if isinstance(mask, xarray.DataTree): |
960 | | - mask = mask["scale0"].image |
| 889 | + dict_elems[self.cyto_seg_name] = sdata_input[cytosol_segmentation_name] |
| 890 | + if remove_duplicates: |
| 891 | + all_elements.remove(cytosol_segmentation_name) |
961 | 892 |
|
962 | | - # ensure that loaded masks are at the same scale as the input image |
963 | | - mask_x, mask_y = mask.shape |
964 | | - assert (mask_x == image_x) and ( |
965 | | - mask_y == image_y |
966 | | - ), "Nucleus segmentation mask does not match input image size." |
| 893 | + if keep_all: |
| 894 | + shutil.rmtree(self.sdata_path, ignore_errors=True) |
| 895 | + for elem in all_elements: |
| 896 | + dict_elems[elem] = sdata_input[elem] |
967 | 897 |
|
968 | | - if rechunk: |
969 | | - self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct |
970 | | - |
971 | | - self.filehandler._write_segmentation_object_sdata(mask, self.cyto_seg_name) |
972 | | - self.log( |
973 | | - f"Calculating centers for cytosol segmentation mask {self.nuc_seg_name} and adding to spatialdata object." |
974 | | - ) |
975 | | - self.filehandler._add_centers(segmentation_label=self.cyto_seg_name) |
| 898 | + sdata = SpatialData.from_elements_dict(dict_elems) |
| 899 | + sdata.write(self.sdata_path, overwrite=True) |
976 | 900 |
|
| 901 | + # update project status |
977 | 902 | self.get_project_status() |
| 903 | + _, x, y = _get_shape(sdata[self.DEFAULT_INPUT_IMAGE_NAME]) |
978 | 904 |
|
979 | | - # ensure that the provided nucleus and cytosol segmentations fullfill the scPortrait requirements |
980 | | - # requirements are: |
981 | | - # 1. The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids |
982 | | - # if self.nuc_seg_status and self.cyto_seg_status: |
983 | | - # THIS NEEDS TO BE IMPLEMENTED HERE |
984 | | - |
985 | | - # 2. the nucleus segmentation ids and the cytosol segmentation ids need to match |
986 | | - # THIS NEEDS TO BE IMPLEMENTED HERE |
987 | | - |
988 | | - # check if there are any annotations that match the nucleus/cytosol segmentations |
989 | | - if self.nuc_seg_status or self.cyto_seg_status: |
990 | | - region_annotation = generate_region_annotation_lookuptable(self.sdata) |
991 | | - |
992 | | - if self.nuc_seg_status: |
993 | | - region_name = self.nuc_seg_name |
994 | | - |
995 | | - # add existing nucleus annotations if available |
996 | | - if nucleus_segmentation_name in region_annotation.keys(): |
997 | | - for x in region_annotation[nucleus_segmentation_name]: |
998 | | - table_name, table = x |
999 | | - |
1000 | | - new_table_name = f"annot_{region_name}_{table_name}" |
| 905 | + self.overwrite = original_overwrite |
1001 | 906 |
|
1002 | | - table = remap_region_annotation_table(table, region_name=region_name) |
| 907 | + if self.nuc_seg_status: |
| 908 | + # check input size |
| 909 | + _, x_mask, y_mask = _get_shape(sdata[self.nuc_seg_name]) |
| 910 | + assert x == x_mask and y == y_mask, "Input image and nucleus segmentation mask do not match in size." |
1003 | 911 |
|
1004 | | - self.filehandler._write_table_object_sdata(table, new_table_name) |
1005 | | - self.log( |
1006 | | - f"Added annotation {new_table_name} to spatialdata object for segmentation object {region_name}." |
1007 | | - ) |
1008 | | - |
1009 | | - if keep_all and remove_duplicates: |
1010 | | - self.log( |
1011 | | - f"Deleting original annotation {table_name} for nucleus segmentation {nucleus_segmentation_name} from sdata object to prevent information duplication." |
1012 | | - ) |
1013 | | - self.filehandler._force_delete_object(self.sdata, name=table_name, type="tables") |
1014 | | - else: |
1015 | | - self.log(f"No region annotation found for the nucleus segmentation {nucleus_segmentation_name}.") |
1016 | | - |
1017 | | - # add cytosol segmentations if available |
1018 | | - if self.cyto_seg_status: |
1019 | | - if cytosol_segmentation_name in region_annotation.keys(): |
1020 | | - for x in region_annotation[cytosol_segmentation_name]: |
1021 | | - table_name, table = x |
1022 | | - region_name = self.cyto_seg_name |
1023 | | - new_table_name = f"annot_{region_name}_{table_name}" |
1024 | | - |
1025 | | - table = remap_region_annotation_table(table, region_name=region_name) |
1026 | | - self.filehandler._write_table_object_sdata(table, new_table_name) |
1027 | | - |
1028 | | - self.log( |
1029 | | - f"Added annotation {new_table_name} to spatialdata object for segmentation object {region_name}." |
1030 | | - ) |
1031 | | - |
1032 | | - if keep_all and remove_duplicates: |
1033 | | - self.log( |
1034 | | - f"Deleting original annotation {table_name} for cytosol segmentation {cytosol_segmentation_name} from sdata object to prevent information duplication." |
1035 | | - ) |
1036 | | - self.filehandler._force_delete_object(self.sdata, name=table_name, type="tables") |
1037 | | - else: |
1038 | | - self.log(f"No region annotation found for the cytosol segmentation {cytosol_segmentation_name}.") |
| 912 | + self.filehandler._add_centers(segmentation_label=self.nuc_seg_name) |
| 913 | + if self.cyto_seg_status: |
| 914 | + # check input size |
| 915 | + _, x_mask, y_mask = _get_shape(sdata[self.cyto_seg_name]) |
| 916 | + assert x == x_mask and y == y_mask, "Input image and nucleus segmentation mask do not match in size." |
1039 | 917 |
|
1040 | | - if keep_all and remove_duplicates: |
1041 | | - # remove input image |
1042 | | - self.log(f"Deleting input image '{input_image_name}' from sdata object to prevent information duplication.") |
1043 | | - self.filehandler._force_delete_object(self.sdata, name=input_image_name, type="images") |
| 918 | + self.filehandler._add_centers(segmentation_label=self.cyto_seg_name) |
1044 | 919 |
|
1045 | | - if self.nuc_seg_status: |
1046 | | - self.log( |
1047 | | - f"Deleting original nucleus segmentation mask '{nucleus_segmentation_name}' from sdata object to prevent information duplication." |
1048 | | - ) |
1049 | | - self.filehandler._force_delete_object(self.sdata, name=nucleus_segmentation_name, type="labels") |
1050 | | - if self.cyto_seg_status: |
1051 | | - self.log( |
1052 | | - f"Deleting original cytosol segmentation mask '{cytosol_segmentation_name}' from sdata object to prevent information duplication." |
1053 | | - ) |
1054 | | - self.filehandler._force_delete_object(self.sdata, name=cytosol_segmentation_name, type="labels") |
| 920 | + if self.nuc_seg_status and self.cyto_seg_status: |
| 921 | + ids_nuc = set(sdata[f"{self.DEFAULT_CENTERS_NAME}_{self.nuc_seg_name}"].index.values) |
| 922 | + ids_cyto = set(sdata[f"{self.DEFAULT_CENTERS_NAME}_{self.cyto_seg_name}"].index.values) |
| 923 | + assert ids_nuc == ids_cyto, "Nucleus and cytosol segmentation masks do not match." |
1055 | 924 |
|
1056 | 925 | self.get_project_status() |
1057 | | - self.overwrite = original_overwrite # reset to original value |
1058 | 926 |
|
1059 | 927 | #### Functions to perform processing #### |
1060 | 928 |
|
|
0 commit comments