Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/gallery/distribution/02_plot_dist_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
data,
kind="kde",
var_names=["mu"],
sample_dims=["draw"],
sample_dims=["draw"],
backend="none" # change to preferred backend
)
pc.add_title("KDE of μ by Chain (Centered Eight)")
pc.show()
1 change: 1 addition & 0 deletions docs/source/gallery/inference_diagnostics/01_plot_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
data,
backend="none" # change to preferred backend
)
pc.add_title("MCMC Sampling Traces: Centered Eight Model")
pc.show()
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
aes_by_visuals={"title": ["color"]}, # change title's color per variable
backend="none",
)
pc.add_title("Posterior Predictive Rootogram for Rugby Model")
pc.show()
1 change: 1 addition & 0 deletions src/arviz_plots/backend/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
savefig,
scale_fig_size,
scatter,
set_figure_title,
set_ticklabel_visibility,
set_y_scale,
show,
Expand Down
52 changes: 49 additions & 3 deletions src/arviz_plots/backend/bokeh/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,19 @@
import numpy as np
from bokeh.colors import Color
from bokeh.io.export import export_png, export_svg
from bokeh.layouts import GridBox, gridplot
from bokeh.layouts import GridBox, column, gridplot
from bokeh.models import (
BoxAnnotation,
ColumnDataSource,
CustomJSTickFormatter,
Div,
FixedTicker,
GridPlot,
Range1d,
Span,
Title,
)
from bokeh.models.css import Styles
from bokeh.plotting import figure as _figure
from bokeh.plotting import output_file, save
from bokeh.plotting import show as _show
Expand Down Expand Up @@ -234,6 +236,48 @@ def savefig(figure, path, **kwargs):
)


def set_figure_title(figure, text, *, color=unset, size=unset, **artist_kws):
"""Set a title for the entire figure.

Parameters
----------
figure : bokeh layout or None
The figure/layout to add the title to.
text : str
The title text.
color : optional
Color of the title text.
size : optional
Font size of the title.
**artist_kws : dict, optional
Additional keyword arguments passed to :class:`~bokeh.models.Div`.

Returns
-------
bokeh layout
The new layout with title added (column of title_div and original figure).
`~bokeh.models.Div`
The title Div element.
"""
if color is None:
color = unset
if size is None:
size = unset

styles = artist_kws.pop("styles", {})
if isinstance(styles, dict):
styles = Styles(**styles)
kwargs = {"color": color, "font_size": _float_or_str_size(size)}
kwargs = {key: value for key, value in kwargs.items() if value is not unset}
styles.update(**kwargs)
if styles.text_align is None:
styles.text_align = "center"

title_div = Div(text=text, styles=styles, **artist_kws)
new_layout = column(title_div, figure)
return new_layout, title_div


def get_figsize(plot_collection):
"""Get the size of the :term:`figure` element and its units."""
figure = plot_collection.viz["figure"].item()
Expand Down Expand Up @@ -290,6 +334,7 @@ def create_plotting_grid(
Whether to create plots with polar coordinate axes.
width_ratios, height_ratios : array-like, optional
Ratios between widths/heights of columns/rows in the generated :term:`plot` grid.
plot_hspace : float, optional
subplot_kws : dict, optional
Passed to :func:`~bokeh.plotting.figure`
**kwargs :
Expand Down Expand Up @@ -393,6 +438,7 @@ def create_plotting_grid(
if squeeze and figures.size == 1:
return None, figures[0, 0]
layout = gridplot(figures.tolist(), **kwargs)

return layout, figures.squeeze() if squeeze else figures


Expand All @@ -408,8 +454,8 @@ def _float_or_str_size(size):

Convert float sizes to string ones in px units.
"""
if size is unset:
return size
if size is unset or size is None:
return unset
if isinstance(size, str):
return size
return f"{size:.0f}px"
Expand Down
1 change: 1 addition & 0 deletions src/arviz_plots/backend/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
savefig,
scale_fig_size,
scatter,
set_figure_title,
set_ticklabel_visibility,
set_y_scale,
show,
Expand Down
38 changes: 37 additions & 1 deletion src/arviz_plots/backend/matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,38 @@ def savefig(figure, path, **kwargs):
figure.savefig(path, **kwargs)


def set_figure_title(figure, text, *, color=None, size=None, **artist_kws):
"""Set a title for the entire figure.

Parameters
----------
figure : `~matplotlib.figure.Figure`
The figure to add the title to.
text : str
The title text.
color : optional
Color of the title text.
size : optional
Font size of the title.
**artist_kws : dict, optional
Additional keyword arguments passed to :func:`~matplotlib.figure.Figure.suptitle`.

Returns
-------
`~matplotlib.figure.Figure`
The figure object (unchanged).
`~matplotlib.text.Text`
The title text object.
"""
kwargs = {}
if color is not None:
kwargs["color"] = color
if size is not None:
kwargs["fontsize"] = size
title_obj = figure.suptitle(text, **kwargs, **artist_kws)
return figure, title_obj


def get_figsize(plot_collection):
"""Get the size of the :term:`figure` element and its units."""
return plot_collection.viz["figure"].item().get_size_inches(), "inches"
Expand Down Expand Up @@ -237,7 +269,10 @@ def create_plotting_grid(
squeeze : bool, default True
sharex, sharey : bool, default False
polar : bool
subplot_kws : bool
width_ratios : list, optional
height_ratios : list, optional
plot_hspace : float, optional
subplot_kws : dict, optional
Passed to :func:`~matplotlib.pyplot.subplots` as ``subplot_kw``
**kwargs: dict, optional
Passed to :func:`~matplotlib.pyplot.subplots`
Expand Down Expand Up @@ -279,6 +314,7 @@ def create_plotting_grid(
for i, ax in enumerate(axes.ravel("C")):
if i >= number:
ax.set_axis_off()

return fig, axes


Expand Down
1 change: 1 addition & 0 deletions src/arviz_plots/backend/none/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
savefig,
scale_fig_size,
scatter,
set_figure_title,
set_ticklabel_visibility,
set_y_scale,
show,
Expand Down
32 changes: 32 additions & 0 deletions src/arviz_plots/backend/none/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,38 @@ def savefig(figure, path, **kwargs):
raise TypeError("'none' backend figures can't be saved.")


def set_figure_title(figure, text, *, color=None, size=None, **artist_kws):
"""Set a title for the entire figure.

Parameters
----------
figure : dict
The figure element dict.
text : str
The title text.
color : optional
Color of the title text.
size : optional
Font size of the title.
**artist_kws : dict, optional
Additional keyword arguments.

Returns
-------
dict
The figure element dict (unchanged).
dict
The title element dict.
"""
title_element = {"function": "set_figure_title", "text": text}
if color is not None:
title_element["color"] = color
if size is not None:
title_element["size"] = size
title_element.update(artist_kws)
return figure, title_element


def get_figsize(plot_collection):
"""Get the size of the :term:`figure` element and its units."""
figure_element = plot_collection.viz["figure"].item()
Expand Down
1 change: 1 addition & 0 deletions src/arviz_plots/backend/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
savefig,
scale_fig_size,
scatter,
set_figure_title,
set_ticklabel_visibility,
set_y_scale,
show,
Expand Down
39 changes: 38 additions & 1 deletion src/arviz_plots/backend/plotly/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,39 @@ def savefig(figure, path, **kwargs):
figure.write_image(path, **kwargs)


def set_figure_title(figure, text, *, color=None, size=None, **artist_kws):
"""Set a title for the entire figure.

Parameters
----------
figure : `~plotly.graph_objects.Figure`
The figure to add the title to.
text : str
The title text.
color : optional
Color of the title text.
size : optional
Font size of the title.
**artist_kws : dict, optional
Additional keyword arguments passed to :meth:`~plotly.graph_objects.Figure.update_layout`.

Returns
-------
`~plotly.graph_objects.Figure`
The figure object (unchanged).
plotly title object
The title layout object from the figure.
"""
title_kwargs = {"text": text, "x": 0.5, "xanchor": "center"}
if color is not None:
title_kwargs["font_color"] = color
if size is not None:
title_kwargs["font_size"] = size
title_kwargs.update(artist_kws)
figure.update_layout(title=title_kwargs)
return figure, figure.layout.title


def create_plotting_grid(
number, # pylint: disable=unused-argument
rows=1,
Expand Down Expand Up @@ -334,7 +367,10 @@ def create_plotting_grid(
squeeze : bool, default True
sharex, sharey : bool, default False
polar : bool
subplot_kws : bool
width_ratios : list, optional
height_ratios : list, optional
plot_hspace : float, optional
subplot_kws : dict, optional
Ignored
**kwargs: dict, optional
Passed to :func:`~plotly.subplots.make_subplots`
Expand Down Expand Up @@ -380,6 +416,7 @@ def create_plotting_grid(
for row in range(rows):
for col in range(cols):
plots[row, col] = PlotlyPlot(figure, row + 1, col + 1)

if squeeze and plots.size == 1:
return figure, plots[0, 0]
return figure, plots.squeeze() if squeeze else plots
Expand Down
54 changes: 54 additions & 0 deletions src/arviz_plots/plot_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,8 @@ def wrap(
Plotting backend.
figure_kwargs : mapping, optional
Passed to :func:`~.backend.create_plotting_grid` of the chosen plotting backend.
To add a figure title, use :meth:`~arviz_plots.PlotCollection.add_title` after
creating the PlotCollection.
**kwargs : mapping, optional
Passed as is to the initializer of ``PlotCollection``. That is,
used for ``aes`` and ``**kwargs`` arguments.
Expand Down Expand Up @@ -890,6 +892,8 @@ def grid(
Plotting backend.
figure_kwargs : mapping, optional
Passed to :func:`~.backend.create_plotting_grid` of the chosen plotting backend.
To add a figure title, use :meth:`~arviz_plots.PlotCollection.add_title` after
creating the PlotCollection.
**kwargs : mapping, optional
Passed as is to the initializer of ``PlotCollection``. That is,
used for ``aes`` and ``**kwargs`` arguments.
Expand Down Expand Up @@ -1276,6 +1280,56 @@ def store_in_artist_da(self, aux_artist, fun_label, var_name, sel):
"""Store the visual object of `var_name`+`sel` combination in `fun_label` variable."""
self.viz[fun_label][var_name].loc[sel] = aux_artist

def add_title(self, text, *, color=None, size=None, **artist_kws):
"""Add a title to the :term:`figure`.

Parameters
----------
text : str
The title text.
color : optional
Color of the title text.
size : optional
Font size of the title.
**artist_kws : mapping, optional
Additional keyword arguments passed to :func:`~.backend.set_figure_title`.

Examples
--------
Add a title after creating a plot:

.. jupyter-execute::

import arviz_base as azb
import arviz_plots as azp

data = azb.load_arviz_data("centered_eight")
pc = azp.plot_dist(data)
pc.add_title("Posterior Distributions")

Add a colored title with custom size:

.. jupyter-execute::

pc = azp.plot_trace(data, var_names=["mu"])
pc.add_title("MCMC Trace", color="darkblue", size=16)
"""
if "figure" not in self.viz.data_vars:
raise ValueError("No figure found to add title to")

plot_bknd = import_module(f".backend.{self.backend}", package="arviz_plots")
fig = self.viz["figure"].item()

new_fig, title_obj = plot_bknd.set_figure_title(
fig, text, color=color, size=size, **artist_kws
)

# bokeh returns a new column layout, so we need to update the stored figure
if new_fig is not fig:
self.viz["figure"] = xr.DataArray(new_fig)

self.viz["figure_title"] = xr.DataArray(title_obj)

def add_legend(
self,
dim,
Expand Down
Loading