Skip to content

Commit 5282cb5

Browse files
committed
better colorbar recognition heuristic
1 parent 13a8734 commit 5282cb5

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

src/tikzplotlib/_axes.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -630,13 +630,19 @@ def _get_ticks(data, xy, ticks, ticklabels):
630630

631631
def _is_colorbar_heuristic(obj):
632632
"""Find out if the object is in fact a color bar."""
633-
# TODO come up with something more accurate here
633+
# Not sure if these properties are always present
634+
if hasattr(obj, "_colorbar") or hasattr(obj, "_colorbar_info"):
635+
return True
636+
637+
# TODO come up with something more accurate here. See
638+
# <https://discourse.matplotlib.org/t/find-out-if-an-axes-object-is-a-colorbar/22563>
634639
# Might help:
635640
# TODO Are the colorbars exactly the l.collections.PolyCollection's?
641+
636642
try:
637643
aspect = float(obj.get_aspect())
638644
except ValueError:
639-
# e.g., aspect == 'equal'
645+
# e.g., aspect in ['equal', 'auto']
640646
return False
641647

642648
# Assume that something is a colorbar if and only if the ratio is above 5.0
@@ -646,10 +652,10 @@ def _is_colorbar_heuristic(obj):
646652
#
647653
# plt.colorbar(im, aspect=5)
648654
#
649-
limit_ratio = 5.0
655+
threshold_ratio = 5.0
650656

651-
return (aspect >= limit_ratio and len(obj.get_xticks()) == 0) or (
652-
aspect <= 1.0 / limit_ratio and len(obj.get_yticks()) == 0
657+
return (aspect >= threshold_ratio and len(obj.get_xticks()) == 0) or (
658+
aspect <= 1.0 / threshold_ratio and len(obj.get_yticks()) == 0
653659
)
654660

655661

tests/test_logplot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ def plot():
77
ax = fig.add_subplot(1, 1, 1)
88
ax.semilogy(a, color="blue", lw=0.25)
99

10-
plt.grid(b=True, which="major", color="g", linestyle="-", linewidth=0.25)
11-
plt.grid(b=True, which="minor", color="r", linestyle="--", linewidth=0.5)
10+
plt.grid(visible=True, which="major", color="g", linestyle="-", linewidth=0.25)
11+
plt.grid(visible=True, which="minor", color="r", linestyle="--", linewidth=0.5)
1212
return fig
1313

1414

0 commit comments

Comments
 (0)