Skip to content

Commit 0435c6e

Browse files
Merge pull request #268 from MannLabs/improve_plotting
[FEATURE] add control over fontsize
2 parents 7eeca55 + 3d51c97 commit 0435c6e

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

src/scportrait/pipeline/project.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,8 @@ def plot_input_image(
617617
channels: list[int] | list[str] | None = None,
618618
normalize: bool = False,
619619
normalization_percentile: tuple[float, float] = (0.01, 0.99),
620+
fontsize: int = 20,
621+
figsize_single_tile=(8, 8),
620622
return_fig: bool = False,
621623
image_name="input_image",
622624
) -> Figure | None:
@@ -626,6 +628,8 @@ def plot_input_image(
626628
max_size: Maximum size of the image to be plotted in pixels.
627629
select_region: Tuple containing the x and y coordinates of the center of the region to be plotted. If not set it will use the center of the image.
628630
channels: List of channel names or indices to be plotted. If not set, the first 4 channels will be plotted.
631+
fontsize: Fontsize of the title of the plot.
632+
figsize_single_tile: Size of the single tile in the plot.
629633
return_fig: If set to ``True``, the function returns the figure object instead of displaying it.
630634
631635
Returns:
@@ -731,16 +735,18 @@ def plot_input_image(
731735
percentile_normalization(im, lower_percentile, upper_percentile) * np.iinfo(np.uint16).max
732736
).astype(np.uint16)
733737

734-
fig, axs = plt.subplots(1, len(channel_names) + 1, figsize=(8 * (len(channel_names) + 1), 8))
735-
_sdata.pl.render_images(image_name, channel=channel_names, palette=palette).pl.show(
736-
ax=axs[0], title="overlayed"
737-
)
738+
fig_size_x, fig_size_y = figsize_single_tile
739+
fig, axs = plt.subplots(1, len(channel_names) + 1, figsize=(fig_size_x * (len(channel_names) + 1), fig_size_y))
740+
_sdata.pl.render_images(image_name, channel=channel_names, palette=palette).pl.show(ax=axs[0])
741+
axs[0].set_title("overlayed", fontsize=fontsize)
738742
axs[0].axis("off")
739743

740744
for i, channel in enumerate(channel_names):
741745
_sdata.pl.render_images(image_name, channel=channel, palette=palette[i]).pl.show(
742-
ax=axs[i + 1], colorbar=False, title=channel
746+
ax=axs[i + 1],
747+
colorbar=False,
743748
)
749+
axs[i + 1].set_title(channel, fontsize=fontsize)
744750
axs[i + 1].axis("off")
745751
fig.tight_layout()
746752

@@ -756,6 +762,7 @@ def plot_he_image(
756762
max_width: int | None = None,
757763
select_region: tuple[int, int] | None = None,
758764
return_fig: bool = False,
765+
fontsize: int = 20,
759766
) -> None | Figure:
760767
"""Plot the hematoxylin and eosin (HE) channel of the input image.
761768
@@ -832,6 +839,7 @@ def plot_segmentation_masks(
832839
normalization_percentile: tuple[float, float] = (0.01, 0.99),
833840
image_name: str = "input_image",
834841
mask_names: list[str] | None = None,
842+
fontsize: int = 20,
835843
return_fig: bool = False,
836844
) -> None | Figure:
837845
"""Plot the generated segmentation masks. If the image is large it will automatically plot a subset cropped to the center of the spatialdata object.
@@ -901,7 +909,9 @@ def plot_segmentation_masks(
901909

902910
# create plot
903911
fig, axs = plt.subplots(1, len(masks) + 1, figsize=(8 * (len(masks) + 1), 8))
904-
plot_segmentation_mask(_sdata, masks, max_width=max_width, axs=axs[0], title="overlayed", show_fig=False)
912+
plot_segmentation_mask(
913+
_sdata, masks, max_width=max_width, axs=axs[0], title="overlayed", font_size=fontsize, show_fig=False
914+
)
905915

906916
for mask in masks:
907917
idx = masks.index(mask)
@@ -923,6 +933,7 @@ def plot_segmentation_masks(
923933
selected_channels=channel,
924934
axs=axs[idx + 1],
925935
title=name,
936+
font_size=fontsize,
926937
show_fig=False,
927938
)
928939

@@ -935,9 +946,20 @@ def plot_segmentation_masks(
935946
return None
936947

937948
def plot_single_cell_images(
938-
self, n_cells: int | None = None, select_channel: int | None = None, return_fig: bool = False
949+
self,
950+
n_cells: int | None = None,
951+
cell_ids: list[int] | None = None,
952+
select_channel: int | None = None,
953+
return_fig: bool = False,
939954
) -> None | Figure:
940-
return cell_grid(self.h5sc, n_cells=n_cells, select_channel=select_channel, return_fig=return_fig)
955+
if cell_ids is not None:
956+
assert n_cells is None, "n_cells and cell_ids cannot be set at the same time."
957+
if n_cells is not None:
958+
assert cell_ids is None, "n_cells and cell_ids cannot be set at the same time."
959+
960+
return cell_grid(
961+
self.h5sc, n_cells=n_cells, cell_ids=cell_ids, select_channel=select_channel, return_fig=return_fig
962+
)
941963

942964
#### Functions to load input data ####
943965
def load_input_from_array(

src/scportrait/plotting/sdata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def plot_segmentation_mask(
7575
selected_channels: int | list[int] | None = None,
7676
select_region: tuple[int, int] | None = None,
7777
axs: plt.Axes | None = None,
78+
font_size: int = 20,
7879
return_fig: bool = False,
7980
show_fig: bool = True,
8081
) -> plt.Figure | None:
@@ -148,7 +149,7 @@ def plot_segmentation_mask(
148149

149150
# turn off axis
150151
axs.axis("off")
151-
axs.set_title(title)
152+
axs.set_title(title, fontsize=font_size)
152153

153154
# return elements
154155
if return_fig:

0 commit comments

Comments
 (0)