Skip to content

Commit f12a51b

Browse files
Merge pull request #335 from MannLabs/fix_bugs_xenium
Fix bugs xenium
2 parents 96a62bd + fe2cfd8 commit f12a51b

File tree

3 files changed

+52
-15
lines changed

3 files changed

+52
-15
lines changed

src/scportrait/pipeline/featurization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _detect_automatic_inference_device(self) -> str:
125125

126126
if torch.cuda.is_available():
127127
inference_device = "cuda"
128-
if torch.backends.mps.is_available():
128+
elif torch.backends.mps.is_available():
129129
inference_device = torch.device("mps")
130130
else:
131131
inference_device = "cpu"

src/scportrait/pipeline/project.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import re
1616
import shutil
1717
import tempfile
18+
import warnings
1819
from pathlib import PosixPath
1920
from typing import TYPE_CHECKING, Literal
2021

@@ -198,7 +199,7 @@ def __init__(
198199
if not os.path.isdir(self.project_location):
199200
os.makedirs(self.project_location)
200201
else:
201-
Warning("There is already a directory in the location path")
202+
warnings.warn("There is already a directory in the location path", stacklevel=2)
202203

203204
# === setup sdata reader/writer ===
204205
self.filehandler = sdata_filehandler(
@@ -499,8 +500,9 @@ def _check_image_dtype(self, image: np.ndarray) -> None:
499500
"""
500501

501502
if not image.dtype == self.DEFAULT_IMAGE_DTYPE:
502-
Warning(
503-
f"Image dtype is not {self.DEFAULT_IMAGE_DTYPE} but insteadt {image.dtype}. The workflow expects images to be of dtype {self.DEFAULT_IMAGE_DTYPE}. Proceeding with the incorrect dtype can lead to unexpected results."
503+
warnings.warn(
504+
f"Image dtype is not {self.DEFAULT_IMAGE_DTYPE} but insteadt {image.dtype}. The workflow expects images to be of dtype {self.DEFAULT_IMAGE_DTYPE}. Proceeding with the incorrect dtype can lead to unexpected results.",
505+
stacklevel=2,
504506
)
505507
self.log(
506508
f"Image dtype is not {self.DEFAULT_IMAGE_DTYPE} but insteadt {image.dtype}. The workflow expects images to be of dtype {self.DEFAULT_IMAGE_DTYPE}. Proceeding with the incorrect dtype can lead to unexpected results."
@@ -638,7 +640,10 @@ def close_interactive_viewer(self):
638640

639641
def _check_for_interactive_session(self):
640642
if self.interactive is not None:
641-
Warning("Interactive viewer is still open. Will automatically close before proceeding with processing.")
643+
warnings.warn(
644+
"Interactive viewer is still open. Will automatically close before proceeding with processing.",
645+
stacklevel=2,
646+
)
642647
self.close_interactive_viewer()
643648

644649
#### Functions to visualize results ####
@@ -1452,8 +1457,9 @@ def load_input_from_sdata(
14521457
table = sdata_input[table_elem]
14531458
rename_columns = {}
14541459
if self.DEFAULT_CELL_ID_NAME in table.obs:
1455-
Warning(
1456-
f"Column {self.DEFAULT_CELL_ID_NAME} already exists in table. Renaming to `f{self.DEFAULT_CELL_ID_NAME}_orig` to preserve compatibility with scPortrait workflow."
1460+
warnings.warn(
1461+
f"Column {self.DEFAULT_CELL_ID_NAME} already exists in table. Renaming to `f{self.DEFAULT_CELL_ID_NAME}_orig` to preserve compatibility with scPortrait workflow.",
1462+
stacklevel=2,
14571463
)
14581464
rename_columns[self.DEFAULT_CELL_ID_NAME] = f"{self.DEFAULT_CELL_ID_NAME}_orig"
14591465
self.log(
@@ -1495,11 +1501,31 @@ def load_input_from_sdata(
14951501

14961502
self.filehandler._add_centers(segmentation_label=self.cyto_seg_name)
14971503

1504+
# read the sdata object from file to ensure we have access to all newly written elements
1505+
sdata = SpatialData.read(self.sdata_path)
1506+
14981507
# ensure that if both an nucleus and cytosol segmentation mask are loaded that they match
14991508
if self.nuc_seg_status and self.cyto_seg_status:
1500-
ids_nuc = set(sdata[f"{self.DEFAULT_CENTERS_NAME}_{self.nuc_seg_name}"].index.values)
1501-
ids_cyto = set(sdata[f"{self.DEFAULT_CENTERS_NAME}_{self.cyto_seg_name}"].index.values)
1502-
assert ids_nuc == ids_cyto, "Nucleus and cytosol segmentation masks do not match."
1509+
ids_nuc = set(sdata["centers_seg_all_nucleus"].index.compute().values)
1510+
ids_cyto = set(sdata["centers_seg_all_cytosol"].index.compute().values)
1511+
1512+
if ids_nuc.issubset(ids_cyto):
1513+
if not ids_nuc == ids_cyto:
1514+
warnings.warn(
1515+
"Nucleus segmentation mask is a subset of cytosol segmentation mask, but they do not match exactly. \n This means that cells exist which do not have a nucleus mask but do have a cytosol mask. \n Please be careful when configuring your extraction workflow.",
1516+
stacklevel=2,
1517+
)
1518+
1519+
elif ids_cyto.issubset(ids_nuc):
1520+
if not ids_cyto == ids_nuc:
1521+
warnings.warn(
1522+
"Cytosol segmentation mask is a subset of nucleus segmentation mask, but they do not match exactly. \n This means cells exist which do not have a cytosol mask but do have a nucleus mask. \n Please be careful when configuring your extraction workflow.",
1523+
stacklevel=2,
1524+
)
1525+
else:
1526+
raise ValueError(
1527+
"Nucleus and cytosol segmentation masks do not match. This is unexpected and should be investigated."
1528+
)
15031529

15041530
self.get_project_status()
15051531

src/scportrait/tools/sdata/write/_write.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,22 @@ def image(
7070
# fix until #https://github.com/scverse/spatialdata/issues/528 is resolved
7171
Image2DModel().validate(image)
7272

73-
if channel_names is not None:
74-
warnings.warn(
75-
"Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray.",
76-
stacklevel=2,
77-
)
73+
# read channel names from the DataArray if not provided
74+
if channel_names is None:
75+
channel_names = image.coords["c"].values.tolist()
76+
else:
77+
if len(channel_names) != image.shape[0]:
78+
raise ValueError(
79+
f"Number of channel names ({len(channel_names)}) does not match the number of channels in the image ({image.shape[0]})."
80+
)
81+
channel_names_old = image.coords["c"].values.tolist()
82+
if channel_names_old != channel_names:
83+
warnings.warn(
84+
f"Channel names in the DataArray ({channel_names_old}) do not match the provided channel names ({channel_names}). The DataArray will be updated with the provided channel names.",
85+
stacklevel=2,
86+
)
7887

88+
# if so first validate the model since this means we are getting the image from a spatialdata object already
7989
if chunks is not None:
8090
warnings.warn(
8191
"Chunks are ignored when passing a single scale image in the DataArray format. Chunks are read directly from the DataArray.",
@@ -84,6 +94,7 @@ def image(
8494

8595
image = Image2DModel.parse(
8696
image,
97+
c_coords=channel_names,
8798
scale_factors=scale_factors,
8899
rgb=rgb,
89100
)

0 commit comments

Comments
 (0)