Skip to content

Commit 558ce0c

Browse files
Add get_axis_by_index function (#313)
* add get_axis_by_index function * add function to init * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add inverse function param_indices_from_axis and rename axis_from_param_indices * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5797e1b commit 558ce0c

File tree

3 files changed

+135
-2
lines changed

3 files changed

+135
-2
lines changed

src/corner/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,22 @@
11
# -*- coding: utf-8 -*-
22

3-
__all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"]
3+
__all__ = [
4+
"corner",
5+
"hist2d",
6+
"quantile",
7+
"overplot_lines",
8+
"overplot_points",
9+
"axis_from_param_indices",
10+
"param_indices_from_axis",
11+
]
412

5-
from corner.core import hist2d, overplot_lines, overplot_points, quantile
13+
from corner.core import (
14+
axis_from_param_indices,
15+
hist2d,
16+
overplot_lines,
17+
overplot_points,
18+
param_indices_from_axis,
19+
quantile,
20+
)
621
from corner.corner import corner
722
from corner.version import version as __version__

src/corner/core.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,84 @@ def overplot_points(fig, xs, reverse=False, **kwargs):
892892
axes[k2, k1].plot(xs[k1], xs[k2], **kwargs)
893893

894894

895+
def axis_from_param_indices(fig, ix, iy, return_axis=True):
896+
"""
897+
Get axis corresponding to indices ``ix``, ``iy`` of the input data. This can be used, e.g., for
898+
manually adding additional data or labels to a specific axis. This is the inverse of
899+
`param_indices_from_axis`.
900+
901+
Parameters
902+
----------
903+
fig : Figure
904+
The figure generated by a call to :func:`corner.corner`.
905+
906+
ix, iy : int
907+
Indices of the parameter list corresponding to the plotted ``x`` and ``y`` axes. Only cases
908+
where ``ix <= iy`` have plotted axes, and ``ix == iy`` corresponds to the histogram axis for
909+
parameter index ``ix``. The function doesn't raise an error when ``ix > iy`` corresponding to one
910+
of the hidden axes, though it does raise an error if either ``ix`` or ``iy`` is too large for the
911+
dimensions of the plotted ``fig``.
912+
913+
return_axis : bool
914+
Return either the axis itself or its integer index
915+
916+
Returns
917+
-------
918+
ax : axis
919+
Entry in the ``fig.axes`` list.
920+
"""
921+
ndim = int(np.sqrt(len(fig.axes)))
922+
if ix > ndim - 1:
923+
msg = f"ix={ix} too large for ndim={ndim}"
924+
raise ValueError(msg)
925+
elif iy > ndim - 1:
926+
msg = f"ix={ix} too large for ndim={ndim}"
927+
raise ValueError(msg)
928+
929+
for i in range(ndim**2):
930+
ix_i = range(ndim)[(i % ndim)]
931+
iy_i = range(ndim)[(i // ndim) - ndim]
932+
933+
if (ix == ix_i) & (iy == iy_i):
934+
break
935+
936+
if return_axis:
937+
return fig.axes[i]
938+
else:
939+
return i
940+
941+
942+
def param_indices_from_axis(fig, i):
943+
"""
944+
Get indices ``ix``, ``iy`` of the input data associated with one of the plotted axes. This is the
945+
inverse of `axis_from_param_indices`.
946+
947+
Parameters
948+
----------
949+
fig : Figure
950+
The figure generated by a call to :func:`corner.corner`.
951+
952+
i : int
953+
Index of an entry in the ``fig.axes`` list
954+
955+
Returns
956+
-------
957+
ix, iy : int
958+
Indices of the figure axes list corresponding to the plotted ``x`` and ``y`` of the specified axis
959+
index
960+
"""
961+
if i > len(fig.axes):
962+
msg = f"{i} must be < len(fig.axes) = {len(fig.axes)}"
963+
raise ValueError(msg)
964+
965+
ndim = int(np.sqrt(len(fig.axes)))
966+
967+
ix = range(ndim)[(i % ndim)]
968+
iy = range(ndim)[(i // ndim) - ndim]
969+
970+
return ix, iy
971+
972+
895973
def _parse_input(xs):
896974
xs = np.atleast_1d(xs)
897975
if len(xs.shape) == 1:

tests/test_corner.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,46 @@ def test_basic():
6565
_run_corner()
6666

6767

68+
def test_axis_index():
69+
70+
labels = ["a", "b", "c"]
71+
fig = _run_corner(labels=labels, n=100)
72+
73+
# This should be x=a vs. y=c plotted in the lower left corner with both labels
74+
ax = corner.axis_from_param_indices(fig, 0, 2)
75+
assert ax.get_xlabel() == labels[0]
76+
assert ax.get_ylabel() == labels[2]
77+
78+
# This should be x=b vs. y=c, to the right of the previous with no y label
79+
ax = corner.axis_from_param_indices(fig, 1, 2)
80+
assert ax.get_xlabel() == labels[1]
81+
assert ax.get_ylabel() == ""
82+
83+
# This should be the histogram of c at the lower right
84+
ax = corner.axis_from_param_indices(fig, 2, 2)
85+
86+
# Some big number, probably 1584 depending on the seed?
87+
assert ax.get_ylim()[1] > 100
88+
89+
# ix > iy is hidden, which have ranges set to (0,1)
90+
ax = corner.axis_from_param_indices(fig, 2, 1)
91+
assert np.allclose(ax.get_xlim(), [0, 1])
92+
assert np.allclose(ax.get_ylim(), [0, 1])
93+
94+
with pytest.raises(ValueError):
95+
ax = corner.axis_from_param_indices(fig, 2, 4)
96+
97+
# Inverse
98+
for ix in range(len(labels)):
99+
for iy in range(ix + 1, len(labels)):
100+
i = corner.axis_from_param_indices(fig, ix, iy, return_axis=False)
101+
ix_i, iy_i = corner.param_indices_from_axis(fig, i)
102+
assert np.allclose([ix_i, iy_i], [ix, iy])
103+
104+
with pytest.raises(ValueError):
105+
_ = corner.param_indices_from_axis(fig, 100)
106+
107+
68108
@image_comparison(
69109
baseline_images=["basic_log"], remove_text=True, extensions=["png"]
70110
)

0 commit comments

Comments
 (0)