Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
env:
CIBW_SKIP: 'pp*'
CIBW_ARCHS: 'auto64'
CIBW_MANYLINUX_X86_64_IMAGE: 'manylinux_2_28'
CIBW_PROJECT_REQUIRES_PYTHON: '>=3.10'
CIBW_TEST_REQUIRES: 'pytest'
defaults:
Expand Down
33 changes: 24 additions & 9 deletions ratapi/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,10 @@ def plot_contour(


def panel_plot_helper(
plot_func: Callable, indices: list[int], fig: matplotlib.figure.Figure | None = None
plot_func: Callable,
indices: list[int],
fig: matplotlib.figure.Figure | None = None,
progress_callback: Callable[[int, int], None] | None = None,
) -> matplotlib.figure.Figure:
"""Generate a panel-based plot from a single plot function.

Expand All @@ -994,6 +997,9 @@ def panel_plot_helper(
The list of indices to pass into ``plot_func``.
fig : matplotlib.figure.Figure, optional
The figure object to use for plot.
progress_callback: Union[Callable[[int, int], None], None]
Callback function for providing progress during plot creation
First argument is current completed sub plot and second is total number of sub plots

Returns
-------
Expand All @@ -1005,21 +1011,21 @@ def panel_plot_helper(
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))

if fig is None:
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
fig = plt.subplots(nrows, ncols, figsize=(11, 10), subplot_kw={"visible": False})[0]
else:
fig.clf()
fig.subplots(nrows, ncols)
fig.subplots(nrows, ncols, subplot_kw={"visible": False})
axs = fig.get_axes()

current_plot = 0
for plot_num, index in enumerate(indices):
axs[plot_num].tick_params(which="both", labelsize="medium")
axs[plot_num].xaxis.offsetText.set_fontsize("small")
axs[plot_num].yaxis.offsetText.set_fontsize("small")
axs[plot_num].set_visible(True)
plot_func(axs[plot_num], index)

# blank unused plots
for i in range(nplots, len(axs)):
axs[i].set_visible(False)
if progress_callback is not None:
current_plot += 1
progress_callback(current_plot, nplots)

fig.tight_layout()
return fig
Expand All @@ -1036,6 +1042,7 @@ def plot_hists(
block: bool = False,
fig: matplotlib.figure.Figure | None = None,
return_fig: bool = False,
progress_callback: Callable[[int, int], None] | None = None,
**hist_settings,
):
"""Plot marginalised posteriors for several parameters from a Bayesian analysis.
Expand Down Expand Up @@ -1072,6 +1079,9 @@ def plot_hists(
The figure object to use for plot.
return_fig: bool, default False
If True, return the figure as an object instead of showing it.
progress_callback: Union[Callable[[int, int], None], None]
Callback function for providing progress during plot creation
First argument is current completed sub plot and second is total number of sub plots
hist_settings :
Settings passed to `np.histogram`. By default, the settings
passed are `bins = 25` and `density = True`.
Expand Down Expand Up @@ -1130,6 +1140,7 @@ def validate_dens_type(dens_type: str | None, param: str):
),
params,
fig,
progress_callback,
)
if return_fig:
return fig
Expand All @@ -1144,6 +1155,7 @@ def plot_chain(
block: bool = False,
fig: matplotlib.figure.Figure | None = None,
return_fig: bool = False,
progress_callback: Callable[[int, int], None] | None = None,
):
"""Plot the MCMC chain for each parameter of a Bayesian analysis.

Expand All @@ -1162,6 +1174,9 @@ def plot_chain(
The figure object to use for plot.
return_fig: bool, default False
If True, return the figure as an object instead of showing it.
progress_callback: Union[Callable[[int, int], None], None]
Callback function for providing progress during plot creation
First argument is current completed sub plot and second is total number of sub plots

Returns
-------
Expand All @@ -1187,7 +1202,7 @@ def plot_one_chain(axes: Axes, i: int):
axes.plot(range(0, nsimulations, skip), chain[:, i][0:nsimulations:skip])
axes.set_title(results.fitNames[i], fontsize="small")

fig = panel_plot_helper(plot_one_chain, params, fig=fig)
fig = panel_plot_helper(plot_one_chain, params, fig, progress_callback)
if return_fig:
return fig
plt.show(block=block)
Expand Down