Skip to content

Commit 2982ce3

Browse files
committed
[TEST] refactor pytest fixtures to live in one file
1 parent fe37f85 commit 2982ce3

File tree

3 files changed

+68
-24
lines changed

3 files changed

+68
-24
lines changed

tests/conftest.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import shutil
2+
3+
import numpy as np
4+
import pandas as pd
5+
import pytest
6+
from anndata import AnnData
7+
from spatialdata import SpatialData
8+
from spatialdata.datasets import blobs
9+
10+
rng = np.random.default_rng()
11+
12+
13+
@pytest.fixture
14+
def h5sc_object() -> AnnData:
15+
# Two cells, two channels, small images
16+
cell_ids = [101, 102, 107, 109]
17+
n_cells = 4
18+
channel_names = np.array(["seg_all_nucleus", "ch0", "ch1"])
19+
channel_mapping = np.array(["mask", "image", "image"]) # or whatever mapping your code expects
20+
n_channels = len(channel_names)
21+
H, W = 10, 10
22+
23+
# --- obs ---
24+
obs = pd.DataFrame({"scportrait_cell_id": cell_ids}, index=np.arange(n_cells))
25+
26+
# --- var (channel metadata) ---
27+
var = pd.DataFrame(index=np.arange(n_channels).astype("str"))
28+
var["channels"] = channel_names
29+
var["channel_mapping"] = channel_mapping
30+
31+
adata = AnnData(obs=obs, var=var)
32+
adata.obsm["single_cell_images"] = rng.random((n_cells, n_channels, H, W))
33+
adata.uns["single_cell_images"] = {
34+
"channel_mapping": channel_mapping,
35+
"channel_names": channel_names,
36+
"compression": "lzf",
37+
"image_size": np.int64(H),
38+
"n_cells": np.int64(n_cells),
39+
"n_channels": np.int64(n_channels),
40+
"n_image_channels": np.int64(n_channels - 1),
41+
"n_masks": np.int64(1),
42+
}
43+
44+
yield adata
45+
46+
47+
@pytest.fixture()
48+
def sdata(tmp_path) -> SpatialData:
49+
sdata = blobs()
50+
# Write to temporary location
51+
sdata_path = tmp_path / "sdata.zarr"
52+
sdata.write(sdata_path)
53+
yield sdata
54+
shutil.rmtree(sdata_path)
55+
56+
57+
@pytest.fixture
58+
def sdata_with_labels() -> SpatialData:
59+
sdata = blobs()
60+
sdata["table"].obs["labelling_categorical"] = sdata["table"].obs["instance_id"].astype("category")
61+
sdata["table"].obs["labelling_continous"] = (sdata["table"].obs["instance_id"] > 10).astype(float)
62+
return sdata

tests/unit_tests/plotting/test_sdata.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,6 @@
1010
from scportrait.plotting import sdata as plotting
1111

1212

13-
@pytest.fixture
14-
def sdata():
15-
sdata = blobs()
16-
sdata["table"].obs["labelling_categorical"] = sdata["table"].obs["instance_id"].astype("category")
17-
sdata["table"].obs["labelling_continous"] = (sdata["table"].obs["instance_id"] > 10).astype(float)
18-
return sdata # provides images and labels used in tests
19-
20-
2113
@pytest.mark.parametrize(
2214
"channel_names, palette, return_fig, show_fig",
2315
[
@@ -26,9 +18,9 @@ def sdata():
2618
([0, 1], None, False, False),
2719
],
2820
)
29-
def test_plot_image(sdata, channel_names, palette, return_fig, show_fig):
21+
def test_plot_image(sdata_with_labels, channel_names, palette, return_fig, show_fig):
3022
fig = plotting.plot_image(
31-
sdata=sdata,
23+
sdata=sdata_with_labels,
3224
image_name="blobs_image",
3325
channel_names=channel_names,
3426
palette=palette,
@@ -51,9 +43,9 @@ def test_plot_image(sdata, channel_names, palette, return_fig, show_fig):
5143
(None, None), # test only mask overlay without image
5244
],
5345
)
54-
def test_plot_segmentation_mask(sdata, selected_channels, background_image):
46+
def test_plot_segmentation_mask(sdata_with_labels, selected_channels, background_image):
5547
fig = plotting.plot_segmentation_mask(
56-
sdata=sdata,
48+
sdata=sdata_with_labels,
5749
masks=["blobs_labels"],
5850
background_image=background_image,
5951
selected_channels=selected_channels,
@@ -73,9 +65,9 @@ def test_plot_segmentation_mask(sdata, selected_channels, background_image):
7365
(True, "labelling_continous"),
7466
],
7567
)
76-
def test_plot_labels(sdata, vectorized, color):
68+
def test_plot_labels(sdata_with_labels, vectorized, color):
7769
fig = plotting.plot_labels(
78-
sdata=sdata,
70+
sdata=sdata_with_labels,
7971
label_layer="blobs_labels",
8072
vectorized=vectorized,
8173
color=color,

tests/unit_tests/tools/sdata/test_write.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,6 @@
77
from spatialdata.datasets import blobs
88

99

10-
@pytest.fixture()
11-
def sdata(tmp_path):
12-
sdata = blobs()
13-
# Write to temporary location
14-
sdata_path = tmp_path / "sdata.zarr"
15-
sdata.write(sdata_path)
16-
yield sdata
17-
shutil.rmtree(sdata_path)
18-
19-
2010
@pytest.fixture()
2111
def sdata_path(tmp_path):
2212
# Write to temporary location

0 commit comments

Comments
 (0)