Skip to content

Commit c5e80ef

Browse files
Merge pull request #372 from MannLabs/add_image_normalization
improve image normalization capabilities
2 parents 973647c + 759e96c commit c5e80ef

File tree

18 files changed

+909
-199
lines changed

18 files changed

+909
-199
lines changed

docs/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def setup(app):
3333
# ones.
3434
extensions = [
3535
"sphinx.ext.autodoc",
36+
"sphinx.ext.autosummary",
3637
"sphinx.ext.napoleon",
3738
"sphinx.ext.doctest",
3839
"sphinxarg.ext",
@@ -95,6 +96,7 @@ def setup(app):
9596
"member-order": "bysource",
9697
}
9798
autoclass_content = "both"
99+
autosummary_generate = True
98100

99101
html_favicon = "favicon.png"
100102
html_logo = "_static/scPortrait_logo_light.svg"

docs/pages/module/plotting.rst

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,22 @@
22
plotting
33
*******************
44

5-
.. automodule:: scportrait.plotting
5+
.. autosummary::
6+
:toctree: generated
7+
:nosignatures:
8+
9+
scportrait.plotting.add_scalebar
10+
scportrait.plotting.colorize
11+
scportrait.plotting.generate_composite
12+
13+
sdata
14+
=====
15+
16+
.. automodule:: scportrait.plotting.sdata
17+
:members:
18+
19+
h5sc
20+
====
21+
22+
.. automodule:: scportrait.plotting.h5sc
623
:members:

src/scportrait/_utils/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Utility helpers for scPortrait."""
2+
3+
from .deprecation import deprecated
4+
5+
__all__ = ["deprecated"]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""Helpers for deprecating public APIs."""
2+
3+
from __future__ import annotations
4+
5+
import functools
6+
import warnings
7+
from collections.abc import Callable
8+
from typing import TypeVar
9+
10+
T = TypeVar("T", bound=Callable[..., object])
11+
12+
try:
13+
from deprecation import deprecated as _deprecated
14+
except (ImportError, ModuleNotFoundError): # pragma: no cover - optional dependency
15+
_deprecated = None
16+
17+
18+
def deprecated(*args, **kwargs):
19+
"""Return a deprecation decorator.
20+
21+
If the optional `deprecation` dependency is installed, this proxies to
22+
`deprecation.deprecated`. Otherwise it falls back to a lightweight wrapper
23+
that emits a DeprecationWarning at call time.
24+
"""
25+
if _deprecated is not None:
26+
return _deprecated(*args, **kwargs)
27+
28+
details = kwargs.get("details", "This function is deprecated and will be removed in a future release.")
29+
30+
def _decorator(func: T) -> T:
31+
@functools.wraps(func)
32+
def _wrapped(*f_args, **f_kwargs):
33+
warnings.warn(details, DeprecationWarning, stacklevel=2)
34+
return func(*f_args, **f_kwargs)
35+
36+
return _wrapped # type: ignore[return-value]
37+
38+
return _decorator

src/scportrait/pipeline/_utils/segmentation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from skimage.transform import resize
1818

1919
from scportrait.pipeline._utils.constants import DEFAULT_SEGMENTATION_DTYPE
20-
from scportrait.plotting.vis import plot_image
20+
from scportrait.plotting._vis import plot_image_array
2121

2222

2323
def global_otsu(image: NDArray) -> float:
@@ -72,7 +72,7 @@ def _segment_threshold(
7272
image_mask = image > threshold
7373

7474
if debug:
75-
plot_image(image_mask, cmap="Greys_r")
75+
plot_image_array(image_mask, cmap="Greys_r")
7676

7777
image_mask_clean = binary_erosion(image_mask, footprint=disk(speckle_kernel))
7878
image_mask_clean = sk_dilation(image_mask_clean, footprint=disk(speckle_kernel - 1))

src/scportrait/pipeline/project.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,12 @@ def plot_input_image(
704704

705705
# subset spatialdata object if its too large
706706
if x > max_width or y > max_width:
707-
_sdata = get_bounding_box_sdata(_sdata, max_width, center_x, center_y)
707+
_sdata = get_bounding_box_sdata(
708+
_sdata,
709+
max_width,
710+
center_y=center_y,
711+
center_x=center_x,
712+
)
708713

709714
if normalize:
710715
lower_percentile, upper_percentile = normalization_percentile
@@ -902,7 +907,7 @@ def plot_segmentation_masks(
902907

903908
# subset spatialdata object if its too large
904909
if x > max_width or y > max_width:
905-
_sdata = get_bounding_box_sdata(_sdata, max_width, center_x, center_y)
910+
_sdata = get_bounding_box_sdata(_sdata, max_width, center_x=center_x, center_y=center_y)
906911

907912
if normalize:
908913
lower_percentile, upper_percentile = normalization_percentile

src/scportrait/pipeline/segmentation/workflows/_cellpose.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,6 @@ def _get_kernel_config(self):
306306
else:
307307
self.kernel_size = 20 # default value
308308

309-
print(self.kernel_size)
310-
311309
def _expand_nucleus_mask(self, nucleus_mask: np.ndarray, kernel_size: int) -> np.ndarray:
312310
"""
313311
Expands the nucleus mask by a given kernel size using dilation
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from ._utils import add_scalebar
2+
from ._vis import colorize, generate_composite
23
from .h5sc import cell_grid, cell_grid_multi_channel, cell_grid_single_channel
3-
from .sdata import plot_image, plot_labels, plot_segmentation_mask
4-
from .vis import generate_composite
4+
from .sdata import plot_image, plot_labels, plot_segmentation_mask, plot_shapes
55

66
__all__ = [
77
"add_scalebar",
@@ -11,5 +11,7 @@
1111
"plot_segmentation_mask",
1212
"plot_image",
1313
"plot_labels",
14+
"plot_shapes",
1415
"generate_composite",
16+
"colorize",
1517
]
Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from scportrait.pipeline.project import Project
1313

1414

15-
def plot_image(
15+
from scportrait._utils.deprecation import deprecated
16+
17+
18+
def plot_image_array(
1619
array: np.ndarray,
1720
size: tuple[int, int] = (10, 10),
1821
save_name: str | None = "",
@@ -34,11 +37,11 @@ def plot_image(
3437
**kwargs: Additional keyword arguments to be passed to `ax.imshow`.
3538
3639
Returns:
37-
None: The function will display the image but does not return any values.
40+
Matplotlib figure if `return_fig=True`, otherwise `None`.
3841
3942
Example:
4043
>>> array = np.random.rand(10, 10)
41-
>>> plot_image(array, size=(5, 5))
44+
>>> plot_image_array(array, size=(5, 5))
4245
"""
4346

4447
fig = plt.figure(frameon=False)
@@ -60,12 +63,25 @@ def plot_image(
6063
plt.close()
6164

6265

66+
@deprecated(
67+
deprecated_in="1.6.2",
68+
removed_in="1.7.0",
69+
details=(
70+
"This function is not used internally and will be removed in a future release. "
71+
"Prefer scportrait.plotting.sdata.plot_labels or scportrait.plotting.sdata.plot_shapes."
72+
),
73+
)
6374
def visualize_class(
6475
class_ids: np.ndarray | list[int], seg_map: np.ndarray, image: np.ndarray, all_ids=None, return_fig=False, **kwargs
6576
):
6677
"""
6778
Visualize specific classes in a segmentation map by highlighting them on top of a background image.
6879
80+
.. deprecated:: 1.6.2
81+
This function is not used internally and will be removed in a future release. Prefer
82+
`scportrait.plotting.sdata.plot_labels` or `scportrait.plotting.sdata.plot_shapes` for
83+
SpatialData-based workflows.
84+
6985
This function takes in class IDs and a segmentation map, and creates an output visualization
7086
where the specified classes are highlighted on top of the provided background image.
7187
@@ -106,12 +122,20 @@ def visualize_class(
106122

107123
vis_map = label2rgb(outmap, image=image, colors=["red", "blue"], alpha=0.4, bg_label=0)
108124

109-
fig = plot_image(vis_map, return_fig=True, **kwargs)
125+
fig = plot_image_array(vis_map, return_fig=True, **kwargs)
110126

111127
if return_fig:
112128
return fig
113129

114130

131+
@deprecated(
132+
deprecated_in="1.6.2",
133+
removed_in="1.7.0",
134+
details=(
135+
"This helper is superseded by scportrait.plotting.sdata.plot_segmentation_mask and "
136+
"is not used internally. It will be removed in a future release."
137+
),
138+
)
115139
def plot_segmentation_mask(
116140
project: Project,
117141
mask_channel: int = 0,
@@ -124,6 +148,10 @@ def plot_segmentation_mask(
124148
) -> object:
125149
"""Visualize the segmentation mask overlayed with a channel of the input image.
126150
151+
.. deprecated:: 1.6.2
152+
This helper is superseded by `scportrait.plotting.sdata.plot_segmentation_mask` and
153+
is not used internally. It will be removed in a future release.
154+
127155
Args:
128156
project: Instance of a scPortrait project.
129157
mask_channel: The index of the channel to use for the segmentation mask.
@@ -167,17 +195,22 @@ def colorize(
167195
im: np.ndarray, color: tuple[int, ...] = (1, 0, 0), clip_percentile: float = 0.0, normalize_image: bool = False
168196
):
169197
"""
170-
Helper function to create an RGB image from a single-channel image using a
171-
specific color.
198+
Create an RGB image from a single-channel image using a specified color.
172199
173200
Args:
174-
im: A single-channel input image. If normalize_image = False, ensure that its values fall between the [0, 1] range.
201+
im: A single-channel input image. If normalize_image = False, ensure that its values fall between [0, 1].
175202
color: The color to use for the image. Defaults to (1, 0, 0).
176-
clip_percentile: The percentile to clip the image at rescaling. Defaults to 0.0 which is equivalent to Min-Max scaling.
177-
normalize_image: boolean value indicating if rescaling should be performed or not.
203+
clip_percentile: Percentile to clip the image at when rescaling. Defaults to 0.0 (min-max scaling).
204+
normalize_image: Whether to rescale the image before colorizing.
178205
179206
Returns:
180207
np.ndarray: The colorized image.
208+
209+
Example:
210+
>>> import numpy as np
211+
>>> from scportrait.plotting import colorize
212+
>>> im = np.random.rand(64, 64)
213+
>>> rgb = colorize(im, color=(0, 1, 0), normalize_image=True)
181214
"""
182215
# Check that we do just have a 2D image
183216
if im.ndim > 2 and im.shape[2] != 1:

0 commit comments

Comments
 (0)