Skip to content

Commit 9b15ebe

Browse files
Merge pull request #367 from MannLabs/bump_spatialdata_version
[VERSION] bumped spatialdata-plot version to 0.2.14 to fix incorrect channel scaling in default plotting
2 parents cce9b89 + be7ea87 commit 9b15ebe

File tree

3 files changed

+107
-16
lines changed

3 files changed

+107
-16
lines changed

requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ spatialdata>=0.3.0,<0.6
88
pyarrow<22.0.0
99
py-lmd>=1.3.1
1010

11-
spatialdata-plot<=0.2.11
11+
spatialdata-plot>=0.2.14
1212
matplotlib
1313

1414
tifffile

src/scportrait/plotting/sdata.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def _get_shape_element(sdata, element_name) -> tuple[int, int]:
5555
_, x, y = shape
5656
elif len(shape) == 2:
5757
x, y = shape
58+
else:
59+
raise ValueError(f"Unsupported shape for element '{element_name}': expected 2D or 3D array, got {shape}.")
5860
return x, y
5961

6062

@@ -115,6 +117,7 @@ def plot_image(
115117
if ax is not None:
116118
if dpi is not None:
117119
warnings.warn("DPI is ignored when an axis is provided.", stacklevel=2)
120+
fig = ax.figure
118121
else:
119122
# get size of spatialdata object to plot (required for calculating figure size if DPI is set)
120123
x, y = _get_shape_element(sdata, image_name)
@@ -189,6 +192,7 @@ def plot_segmentation_mask(
189192
if ax is not None:
190193
if dpi is not None:
191194
warnings.warn("DPI is ignored when an axis is provided.", stacklevel=2)
195+
fig = ax.figure
192196
else:
193197
# get size of spatialdata object to plot (required for calculating figure size if DPI is set)
194198
x, y = _get_shape_element(sdata, masks[0])
@@ -217,20 +221,27 @@ def plot_segmentation_mask(
217221
if selected_channels is not None:
218222
if not isinstance(selected_channels, Iterable):
219223
selected_channels = [selected_channels]
224+
if any(i < 0 or i >= len(channel_names) for i in selected_channels):
225+
raise ValueError(
226+
f"selected_channels contains out-of-range indices for background image '{background_image}'."
227+
)
228+
if len(selected_channels) > len(PALETTE):
229+
raise ValueError("selected_channels has more entries than the available palette length.")
220230
channel_names = [channel_names[i] for i in selected_channels]
221231
c = len(channel_names)
222-
palette = [PALETTE[x] for x in selected_channels]
232+
palette = PALETTE[:c]
223233
else:
224234
if c > max_channels_to_plot:
225-
c = 4
235+
c = min(c, max_channels_to_plot)
226236
palette = PALETTE[:c]
227237
channel_names = list(channel_names[:c])
228238

229239
sdata.pl.render_images(background_image, channel=channel_names, palette=palette).pl.show(ax=ax, colorbar=False)
230240

231241
# plot selected segmentation masks
232242
for mask in masks:
233-
assert mask in sdata, f"Mask {mask} not found in sdata object."
243+
if mask not in sdata:
244+
raise KeyError(f"Mask {mask} not found in sdata object.")
234245
if f"{mask}_vectorized" not in sdata:
235246
sdata[f"{mask}_vectorized"] = spatialdata.to_polygons(sdata[mask])
236247
sdata.pl.render_shapes(
@@ -298,6 +309,7 @@ def plot_shapes(
298309
if ax is not None:
299310
if dpi is not None:
300311
warnings.warn("DPI is ignored when an axis is provided.", stacklevel=2)
312+
fig = ax.figure
301313
else:
302314
# get size of spatialdata object to plot (required for calculating figure size if DPI is set)
303315
x, y = _get_shape_element(sdata, shapes_layer)
@@ -306,7 +318,8 @@ def plot_shapes(
306318
fig, ax = _create_figure_dpi(x=x, y=y, dpi=dpi)
307319

308320
# plot selected shapes layer
309-
assert shapes_layer in sdata, f"Shapes layer {shapes_layer} not found in sdata object."
321+
if shapes_layer not in sdata:
322+
raise KeyError(f"Shapes layer {shapes_layer} not found in sdata object.")
310323

311324
sdata.pl.render_shapes(
312325
f"{shapes_layer}",
@@ -374,6 +387,7 @@ def plot_labels(
374387
if ax is not None:
375388
if dpi is not None:
376389
warnings.warn("DPI is ignored when an axis is provided.", stacklevel=2)
390+
fig = ax.figure
377391
else:
378392
# get size of spatialdata object to plot (required for calculating figure size if DPI is set)
379393
x, y = _get_shape_element(sdata, label_layer)
@@ -410,18 +424,25 @@ def plot_labels(
410424
annotating_table = spatialdata.models.TableModel.parse(annotating_table)
411425
break
412426
if found_annotation is not None:
427+
had_annotation = "_annotation" in sdata
428+
prev_annotation = sdata["_annotation"] if had_annotation else None
413429
sdata["_annotation"] = annotating_table
414-
sdata.pl.render_shapes(
415-
f"{label_layer}_vectorized",
416-
color=color,
417-
fill_alpha=fill_alpha,
418-
outline_alpha=0,
419-
cmap=cmap,
420-
palette=palette,
421-
groups=groups,
422-
norm=norm,
423-
).pl.show(ax=ax)
424-
del sdata["_annotation"] # delete element again after plotting
430+
try:
431+
sdata.pl.render_shapes(
432+
f"{label_layer}_vectorized",
433+
color=color,
434+
fill_alpha=fill_alpha,
435+
outline_alpha=0,
436+
cmap=cmap,
437+
palette=palette,
438+
groups=groups,
439+
norm=norm,
440+
).pl.show(ax=ax)
441+
finally:
442+
if had_annotation:
443+
sdata["_annotation"] = prev_annotation
444+
else:
445+
del sdata["_annotation"] # delete element again after plotting
425446
else:
426447
try:
427448
sdata.pl.render_labels(

tests/unit_tests/plotting/test_sdata.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ def test_plot_image(sdata_with_labels, channel_names, palette, return_fig, show_
3535
assert fig is None
3636

3737

38+
def test_plot_image_with_ax_returns_fig(sdata_with_labels):
39+
fig, ax = plt.subplots()
40+
returned = plotting.plot_image(
41+
sdata=sdata_with_labels,
42+
image_name="blobs_image",
43+
channel_names=[0],
44+
palette=["red"],
45+
return_fig=True,
46+
show_fig=False,
47+
ax=ax,
48+
)
49+
assert returned is fig
50+
plt.close(fig)
51+
52+
3853
@pytest.mark.parametrize(
3954
"selected_channels, background_image",
4055
[
@@ -56,6 +71,33 @@ def test_plot_segmentation_mask(sdata_with_labels, selected_channels, background
5671
plt.close(fig)
5772

5873

74+
def test_plot_segmentation_mask_selected_channels_out_of_range(sdata_with_labels):
75+
with pytest.raises(ValueError):
76+
plotting.plot_segmentation_mask(
77+
sdata=sdata_with_labels,
78+
masks=["blobs_labels"],
79+
background_image="blobs_image",
80+
selected_channels=[999],
81+
return_fig=False,
82+
show_fig=False,
83+
)
84+
85+
86+
def test_plot_segmentation_mask_with_ax_returns_fig(sdata_with_labels):
87+
fig, ax = plt.subplots()
88+
returned = plotting.plot_segmentation_mask(
89+
sdata=sdata_with_labels,
90+
masks=["blobs_labels"],
91+
background_image="blobs_image",
92+
selected_channels=[0],
93+
return_fig=True,
94+
show_fig=False,
95+
ax=ax,
96+
)
97+
assert returned is fig
98+
plt.close(fig)
99+
100+
59101
@pytest.mark.parametrize(
60102
"vectorized, color",
61103
[
@@ -76,3 +118,31 @@ def test_plot_labels(sdata_with_labels, vectorized, color):
76118
)
77119
assert isinstance(fig, plt.Figure)
78120
plt.close(fig)
121+
122+
123+
def test_plot_labels_with_ax_returns_fig(sdata_with_labels):
124+
fig, ax = plt.subplots()
125+
returned = plotting.plot_labels(
126+
sdata=sdata_with_labels,
127+
label_layer="blobs_labels",
128+
vectorized=False,
129+
color="labelling_categorical",
130+
return_fig=True,
131+
show_fig=False,
132+
ax=ax,
133+
)
134+
assert returned is fig
135+
plt.close(fig)
136+
137+
138+
def test_plot_shapes_with_ax_returns_fig(sdata_with_labels):
139+
fig, ax = plt.subplots()
140+
returned = plotting.plot_shapes(
141+
sdata=sdata_with_labels,
142+
shapes_layer="blobs_polygons",
143+
return_fig=True,
144+
show_fig=False,
145+
ax=ax,
146+
)
147+
assert returned is fig
148+
plt.close(fig)

0 commit comments

Comments
 (0)