Skip to content

Commit 44589e7

Browse files
Merge pull request #351 from MannLabs/fix_346
Fix 346
2 parents aa4ba4f + 2982ce3 commit 44589e7

File tree

6 files changed

+142
-31
lines changed

6 files changed

+142
-31
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ config_test.yml
88
segmentation_workflows/
99
scportrait_data/
1010

11+
#do not track vscode settings
12+
.vscode/
1113

1214
#do not track output generated by sphinx-gallery
1315
sg_execution_times.rst

src/scportrait/pipeline/segmentation/workflows/_model_caches.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,9 @@ def _make_zenodo_download_link(record_id: str, filename: str) -> str:
1212
"""
1313
Construct a direct download URL for a file stored in a Zenodo record.
1414
15-
Parameters
16-
----------
17-
record_id : str
18-
The Zenodo record identifier (e.g., "1234567").
19-
filename : str
20-
The exact filename stored in the Zenodo record (case sensitive).
15+
Args:
16+
record_id : The Zenodo record identifier (e.g., "1234567").
17+
filename : The exact filename stored in the Zenodo record (case sensitive).
2118
2219
Returns
2320
-------
@@ -28,7 +25,11 @@ def _make_zenodo_download_link(record_id: str, filename: str) -> str:
2825

2926

3027
def _scportrait_cache_model_path(basename: str) -> None:
31-
"""Download a model from a public Nextcloud share into Cellpose's model cache if missing."""
28+
"""Download a model from a public Zenodo share into Cellpose's model cache if missing.
29+
30+
Args:
31+
32+
"""
3233
MODEL_DIR.mkdir(parents=True, exist_ok=True)
3334

3435
url = _make_zenodo_download_link(

tests/conftest.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import shutil
2+
3+
import numpy as np
4+
import pandas as pd
5+
import pytest
6+
from anndata import AnnData
7+
from spatialdata import SpatialData
8+
from spatialdata.datasets import blobs
9+
10+
rng = np.random.default_rng()
11+
12+
13+
@pytest.fixture
14+
def h5sc_object() -> AnnData:
15+
# Two cells, two channels, small images
16+
cell_ids = [101, 102, 107, 109]
17+
n_cells = 4
18+
channel_names = np.array(["seg_all_nucleus", "ch0", "ch1"])
19+
channel_mapping = np.array(["mask", "image", "image"]) # or whatever mapping your code expects
20+
n_channels = len(channel_names)
21+
H, W = 10, 10
22+
23+
# --- obs ---
24+
obs = pd.DataFrame({"scportrait_cell_id": cell_ids}, index=np.arange(n_cells))
25+
26+
# --- var (channel metadata) ---
27+
var = pd.DataFrame(index=np.arange(n_channels).astype("str"))
28+
var["channels"] = channel_names
29+
var["channel_mapping"] = channel_mapping
30+
31+
adata = AnnData(obs=obs, var=var)
32+
adata.obsm["single_cell_images"] = rng.random((n_cells, n_channels, H, W))
33+
adata.uns["single_cell_images"] = {
34+
"channel_mapping": channel_mapping,
35+
"channel_names": channel_names,
36+
"compression": "lzf",
37+
"image_size": np.int64(H),
38+
"n_cells": np.int64(n_cells),
39+
"n_channels": np.int64(n_channels),
40+
"n_image_channels": np.int64(n_channels - 1),
41+
"n_masks": np.int64(1),
42+
}
43+
44+
yield adata
45+
46+
47+
@pytest.fixture()
48+
def sdata(tmp_path) -> SpatialData:
49+
sdata = blobs()
50+
# Write to temporary location
51+
sdata_path = tmp_path / "sdata.zarr"
52+
sdata.write(sdata_path)
53+
yield sdata
54+
shutil.rmtree(sdata_path)
55+
56+
57+
@pytest.fixture
58+
def sdata_with_labels() -> SpatialData:
59+
sdata = blobs()
60+
sdata["table"].obs["labelling_categorical"] = sdata["table"].obs["instance_id"].astype("category")
61+
sdata["table"].obs["labelling_continous"] = (sdata["table"].obs["instance_id"] > 10).astype(float)
62+
return sdata
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import pytest
2+
import torch
3+
4+
from scportrait.pipeline.featurization import CellFeaturizer # adjust import to your project
5+
6+
7+
def test_cell_featurizer(tmp_path):
8+
# temp directory unique to this test
9+
out_dir = tmp_path / "featurization"
10+
out_dir.mkdir()
11+
12+
config = {
13+
"batch_size": 100,
14+
"dataloader_worker_number": 10,
15+
}
16+
17+
f = CellFeaturizer(
18+
config=config,
19+
directory=str(out_dir),
20+
project_location=None,
21+
overwrite=True,
22+
)
23+
24+
# Image: 1 batch, 1 channel, 4x4
25+
img = torch.zeros((1, 1, 4, 4), dtype=torch.float32)
26+
27+
# Define a "cell" region and a "nucleus" region inside it
28+
nucleus_mask = torch.zeros((1, 1, 4, 4), dtype=torch.float32)
29+
cell_mask = torch.zeros((1, 1, 4, 4), dtype=torch.float32)
30+
31+
cell_mask[..., 1:3, 1:3] = 1.0 # 2x2 cell block
32+
nucleus_mask[..., 1:2, 1:2] = 1.0 # 1x1 nucleus (top-left of cell)
33+
34+
# Intensities: nucleus=10, cytosol=1
35+
img[cell_mask.bool()] = 1.0
36+
img[nucleus_mask.bool()] = 10.0
37+
38+
# Build label stack
39+
labels = torch.cat([nucleus_mask, cell_mask], dim=1) # shape (1, 2, 4, 4)
40+
41+
# Run featurization on concatenated masks and images
42+
feats = f.calculate_statistics(torch.cat([labels, img], dim=1))
43+
column_names = f._generate_column_names(n_masks=2, channel_names=["ch0"])
44+
feat_map = dict(zip(column_names, feats[0].tolist(), strict=True))
45+
46+
# Key regression assertion:
47+
# If masks accidentally become all-True, these will match (or be very close).
48+
nuc_mean = feat_map["ch0_mean_nucleus"]
49+
cyto_mean = feat_map["ch0_mean_cytosol"]
50+
51+
assert nuc_mean > cyto_mean
52+
assert abs(nuc_mean - cyto_mean) > 1.0
53+
54+
55+
def test_mask_bool_integrity_not_corrupted():
56+
mask = torch.tensor([[True, False], [False, True]])
57+
before_true = mask.sum().item()
58+
59+
# This simulates the buggy behavior:
60+
# mask[mask == 0] = torch.nan # would cast nan->True and change mask
61+
# Instead, ensure your implementation never does this.
62+
63+
assert mask.dtype == torch.bool
64+
assert mask.sum().item() == before_true

tests/unit_tests/plotting/test_sdata.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,6 @@
1010
from scportrait.plotting import sdata as plotting
1111

1212

13-
@pytest.fixture
14-
def sdata():
15-
sdata = blobs()
16-
sdata["table"].obs["labelling_categorical"] = sdata["table"].obs["instance_id"].astype("category")
17-
sdata["table"].obs["labelling_continous"] = (sdata["table"].obs["instance_id"] > 10).astype(float)
18-
return sdata # provides images and labels used in tests
19-
20-
2113
@pytest.mark.parametrize(
2214
"channel_names, palette, return_fig, show_fig",
2315
[
@@ -26,9 +18,9 @@ def sdata():
2618
([0, 1], None, False, False),
2719
],
2820
)
29-
def test_plot_image(sdata, channel_names, palette, return_fig, show_fig):
21+
def test_plot_image(sdata_with_labels, channel_names, palette, return_fig, show_fig):
3022
fig = plotting.plot_image(
31-
sdata=sdata,
23+
sdata=sdata_with_labels,
3224
image_name="blobs_image",
3325
channel_names=channel_names,
3426
palette=palette,
@@ -51,9 +43,9 @@ def test_plot_image(sdata, channel_names, palette, return_fig, show_fig):
5143
(None, None), # test only mask overlay without image
5244
],
5345
)
54-
def test_plot_segmentation_mask(sdata, selected_channels, background_image):
46+
def test_plot_segmentation_mask(sdata_with_labels, selected_channels, background_image):
5547
fig = plotting.plot_segmentation_mask(
56-
sdata=sdata,
48+
sdata=sdata_with_labels,
5749
masks=["blobs_labels"],
5850
background_image=background_image,
5951
selected_channels=selected_channels,
@@ -73,9 +65,9 @@ def test_plot_segmentation_mask(sdata, selected_channels, background_image):
7365
(True, "labelling_continous"),
7466
],
7567
)
76-
def test_plot_labels(sdata, vectorized, color):
68+
def test_plot_labels(sdata_with_labels, vectorized, color):
7769
fig = plotting.plot_labels(
78-
sdata=sdata,
70+
sdata=sdata_with_labels,
7971
label_layer="blobs_labels",
8072
vectorized=vectorized,
8173
color=color,

tests/unit_tests/tools/sdata/test_write.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,6 @@
77
from spatialdata.datasets import blobs
88

99

10-
@pytest.fixture()
11-
def sdata(tmp_path):
12-
sdata = blobs()
13-
# Write to temporary location
14-
sdata_path = tmp_path / "sdata.zarr"
15-
sdata.write(sdata_path)
16-
yield sdata
17-
shutil.rmtree(sdata_path)
18-
19-
2010
@pytest.fixture()
2111
def sdata_path(tmp_path):
2212
# Write to temporary location

0 commit comments

Comments
 (0)