Skip to content

Commit eb2dd12

Browse files
committed
Merge branch 'main' into pre-commit-ci-update-config
2 parents ddc8049 + a0449dd commit eb2dd12

File tree

8 files changed

+101
-4
lines changed

8 files changed

+101
-4
lines changed

src/scportrait/pipeline/extraction.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from scportrait.pipeline._base import ProcessingStep
2323
from scportrait.pipeline._utils.helper import flatten
2424
from scportrait.processing.images._image_processing import percentile_normalization
25+
from scportrait.tools.sdata.write._helper import _normalize_anndata_strings
2526

2627

2728
class HDF5CellExtraction(ProcessingStep):
@@ -753,6 +754,7 @@ def _initialize_empty_anndata(self) -> None:
753754
adata.uns[f"{self.DEFAULT_NAME_SINGLE_CELL_IMAGES}/compression"] = self.compression_type
754755

755756
# write to file
757+
_normalize_anndata_strings(adata)
756758
adata.write(self.output_path)
757759

758760
def _create_output_files(self) -> None:

src/scportrait/pipeline/featurization.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1582,15 +1582,15 @@ def calculate_statistics(self, img: torch.Tensor, n_masks: int = 2):
15821582
for channel in range(n_masks, img.shape[1]):
15831583
img_selected = img[:, channel]
15841584

1585-
for mask in masks:
1585+
for mask, area in zip(masks, mask_statistics, strict=True):
15861586
_img_selected = img_selected.masked_fill(~mask, torch.nan).to(torch.float32)
15871587

15881588
mean = _img_selected.view(N, -1).nanmean(1, keepdim=True)
15891589
median = _img_selected.view(N, -1).nanquantile(q=0.5, dim=1, keepdim=True)
15901590
quant75 = _img_selected.view(N, -1).nanquantile(q=0.75, dim=1, keepdim=True)
15911591
quant25 = _img_selected.view(N, -1).nanquantile(q=0.25, dim=1, keepdim=True)
1592-
summed_intensity = _img_selected.view(N, -1).sum(1, keepdim=True)
1593-
summed_intensity_area_normalized = summed_intensity / mask_statistics[-1]
1592+
summed_intensity = _img_selected.view(N, -1).nansum(1, keepdim=True)
1593+
summed_intensity_area_normalized = summed_intensity / area
15941594

15951595
# save results
15961596
channel_statistics.extend(

src/scportrait/pipeline/project.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@
3737
get_chunk_size,
3838
rechunk_image,
3939
)
40-
from scportrait.tools.sdata.write._helper import _get_image, _get_shape
40+
from scportrait.tools.sdata.write._helper import _get_image, _get_shape, _normalize_anndata_strings
4141

4242
if TYPE_CHECKING:
4343
from collections.abc import Callable
@@ -1473,6 +1473,7 @@ def load_input_from_sdata(
14731473
table.obs["region"] = new_name
14741474
table.obs["region"] = table.obs["region"].astype("category")
14751475
table.obs.rename(columns=rename_columns, inplace=True)
1476+
_normalize_anndata_strings(table)
14761477

14771478
if keep_all:
14781479
shutil.rmtree(self.sdata_path, ignore_errors=True)

src/scportrait/tools/sdata/write/_helper.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,9 @@
22
from typing import Any, Literal, TypeAlias
33

44
import numpy as np
5+
import pandas as pd
6+
from anndata import AnnData
7+
from pandas.api.types import is_string_dtype
58
from spatialdata import SpatialData, read_zarr
69
from spatialdata.models import get_model
710
from xarray import DataArray, DataTree
@@ -68,6 +71,32 @@ def _make_key_lookup(sdata: SpatialData) -> dict:
6871
return dict_lookup
6972

7073

74+
def _normalize_dataframe_strings(df: pd.DataFrame) -> None:
75+
"""Normalize string dtypes to object to avoid nullable string serialization issues."""
76+
if is_string_dtype(df.index.dtype):
77+
df.index = df.index.astype(object)
78+
if df.index.isna().any():
79+
df.index = df.index.where(~df.index.isna(), None)
80+
81+
string_cols = df.select_dtypes(include=["string"]).columns
82+
if len(string_cols) > 0:
83+
df[string_cols] = df[string_cols].astype(object)
84+
for col in string_cols:
85+
if df[col].isna().any():
86+
df[col] = df[col].where(df[col].notna(), None)
87+
88+
cat_cols = df.select_dtypes(include=["category"]).columns
89+
for col in cat_cols:
90+
if is_string_dtype(df[col].cat.categories.dtype):
91+
df[col] = df[col].cat.set_categories(df[col].cat.categories.astype(object))
92+
93+
94+
def _normalize_anndata_strings(adata: AnnData) -> None:
    """Apply ``_normalize_dataframe_strings`` to ``adata.obs`` and then ``adata.var``.

    Both frames are modified in place; nothing is returned.
    """
    # Look each frame up only when its turn comes: obs first, then var.
    for frame_name in ("obs", "var"):
        _normalize_dataframe_strings(getattr(adata, frame_name))
98+
99+
71100
def _force_delete_object(sdata: SpatialData, name: str) -> None:
72101
"""Force delete an object from the SpatialData object and directory.
73102
@@ -110,6 +139,9 @@ def add_element_sdata(sdata: SpatialData, element: Any, element_name: str, overw
110139

111140
_force_delete_object(sdata, element_name)
112141

142+
if isinstance(element, AnnData):
143+
_normalize_anndata_strings(element)
144+
113145
# the element needs to validate with exactly one of the models
114146
get_model(element)
115147

tests/conftest.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,30 @@
11
import shutil
22

3+
import matplotlib
4+
import matplotlib.pyplot as plt
35
import numpy as np
46
import pandas as pd
57
import pytest
68
from anndata import AnnData
9+
from matplotlib.figure import Figure
710
from spatialdata import SpatialData
811
from spatialdata.datasets import blobs
912

13+
from scportrait.tools.sdata.write._helper import _normalize_anndata_strings
14+
1015
rng = np.random.default_rng()
1116

1217

18+
@pytest.fixture(autouse=True)
def _disable_matplotlib_show(monkeypatch):
    """Autouse fixture: switch matplotlib to the Agg backend and stub out ``show``.

    ``plt.show`` and ``Figure.show`` are replaced through ``monkeypatch`` with
    a callable that accepts any arguments and returns ``None``.
    """

    def _noop(*args, **kwargs):
        return None

    # Force a non-interactive backend for e2e runs
    matplotlib.use("Agg", force=True)

    # Disable any implicit rendering during tests
    for owner in (plt, Figure):
        monkeypatch.setattr(owner, "show", _noop)
26+
27+
1328
@pytest.fixture
1429
def h5sc_object() -> AnnData:
1530
# Two cells, two channels, small images
@@ -47,6 +62,7 @@ def h5sc_object() -> AnnData:
4762
@pytest.fixture()
4863
def sdata(tmp_path) -> SpatialData:
4964
sdata = blobs()
65+
_normalize_anndata_strings(sdata["table"])
5066
# Write to temporary location
5167
sdata_path = tmp_path / "sdata.zarr"
5268
sdata.write(sdata_path)
@@ -57,6 +73,7 @@ def sdata(tmp_path) -> SpatialData:
5773
@pytest.fixture
def sdata_with_labels() -> SpatialData:
    """Return the ``blobs`` dataset with two extra label columns on its table.

    The table's string dtypes are normalized first. Two obs columns derived
    from ``instance_id`` are then added: a categorical copy, and a float flag
    that is 1.0 where ``instance_id > 10`` and 0.0 elsewhere.
    """
    sdata = blobs()
    table = sdata["table"]
    _normalize_anndata_strings(table)

    instance_ids = table.obs["instance_id"]
    table.obs["labelling_categorical"] = instance_ids.astype("category")
    table.obs["labelling_continous"] = (instance_ids > 10).astype(float)
    return sdata

tests/e2e_tests/test_data_loaders.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,13 @@
88

99
from scportrait.data._datasets import dataset_1_omezarr
1010
from scportrait.pipeline.project import Project
11+
from scportrait.tools.sdata.write._helper import _normalize_anndata_strings
1112

1213

1314
@pytest.fixture()
1415
def sdata_path(tmp_path):
1516
sdata = blobs()
17+
_normalize_anndata_strings(sdata["table"])
1618
# Write to temporary location
1719
sdata_path = tmp_path / "sdata.zarr"
1820
sdata.write(sdata_path)

tests/unit_tests/pipeline/test_featurization.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23
import torch
34

@@ -51,6 +52,21 @@ def test_cell_featurizer(tmp_path):
5152
assert nuc_mean > cyto_mean
5253
assert abs(nuc_mean - cyto_mean) > 1.0
5354

55+
# Ensure masks are applied and sums are finite (no NaNs from masked sum).
56+
assert not np.isnan(feat_map["ch0_summed_intensity_nucleus"])
57+
assert not np.isnan(feat_map["ch0_summed_intensity_cytosol"])
58+
assert not np.isnan(feat_map["ch0_summed_intensity_cytosol_only"])
59+
60+
# Check exact expected values for this synthetic setup.
61+
assert feat_map["ch0_summed_intensity_nucleus"] == pytest.approx(10.0)
62+
assert feat_map["ch0_summed_intensity_cytosol"] == pytest.approx(13.0)
63+
assert feat_map["ch0_summed_intensity_cytosol_only"] == pytest.approx(3.0)
64+
65+
# Area-normalized sums should use each mask's own area (not the last mask).
66+
assert feat_map["ch0_summed_intensity_area_normalized_nucleus"] == pytest.approx(10.0)
67+
assert feat_map["ch0_summed_intensity_area_normalized_cytosol"] == pytest.approx(13.0 / 4.0)
68+
assert feat_map["ch0_summed_intensity_area_normalized_cytosol_only"] == pytest.approx(1.0)
69+
5470

5571
def test_mask_bool_integrity_not_corrupted():
5672
mask = torch.tensor([[True, False], [False, True]])

tests/unit_tests/tools/sdata/test_write.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
import shutil
22

33
import numpy as np
4+
import pandas as pd
45
import pytest
56
import scportrait.tl.sdata.write as write
7+
from anndata import AnnData
68
from spatialdata import SpatialData, read_zarr
79
from spatialdata.datasets import blobs
810

@@ -59,6 +61,31 @@ def test_add_element_sdata(sdata, sdata_path, element_name):
5961
assert element_name in sdata
6062

6163

64+
def _make_table_from_dataframe() -> AnnData:
    """Build a 3x1 AnnData table whose obs/var use pandas ``string`` dtypes.

    The obs frame has a ``string`` index, a ``string`` column (``region``) and
    a categorical column with ``string`` categories (``group``); var has a
    ``string`` index. ``X`` is all zeros (float32).

    Returns:
        The assembled ``AnnData`` object.
    """
    # The Series must carry the same index as the frame. pandas aligns a dict
    # of Series on the requested index, so Series left on their default
    # RangeIndex (0, 1, 2) would not match the labels "1", "2", "3" and every
    # obs value would silently become missing.
    obs_index = pd.Index(["1", "2", "3"], dtype="string")
    obs = pd.DataFrame(
        {
            "region": pd.Series(["1", "2", "3"], index=obs_index, dtype="string"),
            "group": pd.Series(["a", "b", "a"], index=obs_index, dtype="string").astype("category"),
        },
        index=obs_index,
    )
    var = pd.DataFrame(index=pd.Index(["feature_0"], dtype="string"))
    X = np.zeros((3, 1), dtype=np.float32)
    return AnnData(X=X, obs=obs, var=var)
75+
76+
77+
def test_add_element_sdata_table_from_dataframe(sdata_path):
    """A table with pandas ``string`` dtypes can be added and read back from disk."""
    # Start from an empty SpatialData object that is already backed by a store.
    empty_sdata = SpatialData()
    empty_sdata.write(sdata_path)

    string_table = _make_table_from_dataframe()
    write._helper.add_element_sdata(empty_sdata, string_table, "table", overwrite=True)

    # The element must be present both in memory and after reloading.
    assert "table" in empty_sdata
    reloaded = read_zarr(sdata_path)
    assert "table" in reloaded
87+
88+
6289
### test scportrait.tools.sdata.write._helper.rename_image_element
6390
@pytest.mark.parametrize(
6491
"old_name, new_name",

0 commit comments

Comments
 (0)