Skip to content

Commit 26db3ca

Browse files
authored
Merge pull request #152 from benoit9126/master
First implementation of best location for legends
2 parents b72a6ed + f4f2bc9 commit 26db3ca

File tree

5 files changed

+157
-25
lines changed

5 files changed

+157
-25
lines changed

matplotlib2tikz/legend.py

Lines changed: 98 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# -*- coding: utf-8 -*-
22
#
33
import warnings
4+
5+
import numpy
6+
47
from . import color as mycol
58

69

@@ -19,52 +22,126 @@ def draw_legend(data, obj):
1922
# Get the location.
2023
# http://matplotlib.org/api/legend_api.html
2124
pad = 0.03
22-
if obj._loc == 0:
25+
loc = obj._loc
26+
if loc == 0:
2327
# best
24-
print(
25-
'Legend location "best" not yet implemented, '
26-
'choosing "upper right" instead.'
27-
)
28-
position = None
29-
anchor = None
30-
elif obj._loc == 1:
28+
# Create a renderer
29+
from matplotlib.backends import backend_agg
30+
renderer = backend_agg.RendererAgg(width=obj.figure.get_figwidth(),
31+
height=obj.figure.get_figheight(),
32+
dpi=obj.figure.dpi)
33+
34+
# Rectangles of the legend and of the axes
35+
# Lower left and upper right points
36+
x0_legend, x1_legend = obj._legend_box \
37+
.get_window_extent(renderer).get_points()
38+
x0_axes, x1_axes = obj.axes.get_window_extent(renderer).get_points()
39+
dimension_legend = x1_legend - x0_legend
40+
dimension_axes = x1_axes - x0_axes
41+
42+
# To determine the actual position of the legend, check which corner
43+
# (or center) of the legend is closest to the corresponding corner
44+
# (or center) of the axes box.
45+
# 1. Key points of the legend
46+
lower_left_legend = x0_legend
47+
lower_right_legend = numpy.array([x1_legend[0], x0_legend[1]],
48+
dtype=numpy.float_)
49+
upper_left_legend = numpy.array([x0_legend[0], x1_legend[1]],
50+
dtype=numpy.float_)
51+
upper_right_legend = x1_legend
52+
center_legend = x0_legend + dimension_legend / 2.
53+
center_left_legend = numpy.array(
54+
[x0_legend[0], x0_legend[1] + dimension_legend[1] / 2.],
55+
dtype=numpy.float_)
56+
center_right_legend = numpy.array(
57+
[x1_legend[0], x0_legend[1] + dimension_legend[1] / 2.],
58+
dtype=numpy.float_)
59+
lower_center_legend = numpy.array(
60+
[x0_legend[0] + dimension_legend[0] / 2., x0_legend[1]],
61+
dtype=numpy.float_)
62+
upper_center_legend = numpy.array(
63+
[x0_legend[0] + dimension_legend[0] / 2., x1_legend[1]],
64+
dtype=numpy.float_)
65+
66+
# 2. Key points of the axes
67+
lower_left_axes = x0_axes
68+
lower_right_axes = numpy.array([x1_axes[0], x0_axes[1]],
69+
dtype=numpy.float_)
70+
upper_left_axes = numpy.array([x0_axes[0], x1_axes[1]],
71+
dtype=numpy.float_)
72+
upper_right_axes = x1_axes
73+
center_axes = x0_axes + dimension_axes / 2.
74+
center_left_axes = numpy.array(
75+
[x0_axes[0], x0_axes[1] + dimension_axes[1] / 2.],
76+
dtype=numpy.float_)
77+
center_right_axes = numpy.array(
78+
[x1_axes[0], x0_axes[1] + dimension_axes[1] / 2.],
79+
dtype=numpy.float_)
80+
lower_center_axes = numpy.array(
81+
[x0_axes[0] + dimension_axes[0] / 2., x0_axes[1]],
82+
dtype=numpy.float_)
83+
upper_center_axes = numpy.array(
84+
[x0_axes[0] + dimension_axes[0] / 2., x1_axes[1]],
85+
dtype=numpy.float_)
86+
87+
# 3. Compute the distances between comparable points.
88+
distances = {
89+
1: upper_right_axes - upper_right_legend, # upper right
90+
2: upper_left_axes - upper_left_legend, # upper left
91+
3: lower_left_axes - lower_left_legend, # lower left
92+
4: lower_right_axes - lower_right_legend, # lower right
93+
# 5:, Not Implemented # right
94+
6: center_left_axes - center_left_legend, # center left
95+
7: center_right_axes - center_right_legend, # center right
96+
8: lower_center_axes - lower_center_legend, # lower center
97+
9: upper_center_axes - upper_center_legend, # upper center
98+
10: center_axes - center_legend # center
99+
}
100+
for k, v in distances.items():
101+
distances[k] = numpy.linalg.norm(v, ord=2)
102+
103+
# 4. Take the shortest distance between key points as the final
104+
# location
105+
loc = min(distances, key=distances.get)
106+
107+
if loc == 1:
31108
# upper right
32109
position = None
33110
anchor = None
34-
elif obj._loc == 2:
111+
elif loc == 2:
35112
# upper left
36113
position = [pad, 1.0 - pad]
37114
anchor = 'north west'
38-
elif obj._loc == 3:
115+
elif loc == 3:
39116
# lower left
40117
position = [pad, pad]
41118
anchor = 'south west'
42-
elif obj._loc == 4:
119+
elif loc == 4:
43120
# lower right
44121
position = [1.0 - pad, pad]
45122
anchor = 'south east'
46-
elif obj._loc == 5:
123+
elif loc == 5:
47124
# right
48125
position = [1.0 - pad, 0.5]
49-
anchor = 'west'
50-
elif obj._loc == 6:
126+
anchor = 'east'
127+
elif loc == 6:
51128
# center left
52129
position = [3 * pad, 0.5]
53-
anchor = 'east'
54-
elif obj._loc == 7:
130+
anchor = 'west'
131+
elif loc == 7:
55132
# center right
56133
position = [1.0 - 3 * pad, 0.5]
57-
anchor = 'west'
58-
elif obj._loc == 8:
134+
anchor = 'east'
135+
elif loc == 8:
59136
# lower center
60137
position = [0.5, 3 * pad]
61138
anchor = 'south'
62-
elif obj._loc == 9:
139+
elif loc == 9:
63140
# upper center
64141
position = [0.5, 1.0 - 3 * pad]
65142
anchor = 'north'
66143
else:
67-
assert obj._loc == 10
144+
assert loc == 10
68145
# center
69146
position = [0.5, 0.5]
70147
anchor = 'center'
@@ -100,7 +177,7 @@ def draw_legend(data, obj):
100177
if alignment != childAlignment:
101178
warnings.warn(
102179
'Varying horizontal alignments in the legend. Using default.'
103-
)
180+
)
104181
alignment = None
105182
break
106183

test/testfunctions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
'fillstyle',
1616
'heat',
1717
'image_plot',
18+
'legend_best_location',
1819
'legends2',
1920
'legends',
2021
'line_collection',
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
desc = 'Legend best positions'
4+
phash = '879d991d2c877c1c'
5+
6+
7+
def plot():
8+
from matplotlib import pyplot as plt
9+
import numpy as np
10+
11+
fig, ax = plt.subplots(3, 3, sharex='col', sharey='row')
12+
axes = [ax[i][j] for i in range(len(ax)) for j in range(len(ax[i]))]
13+
t = np.arange(0.0, 2.0 * np.pi, 0.1)
14+
15+
# Legend best location is "upper right"
16+
l, = axes[0].plot(t, np.cos(t) * np.exp(-t), linewidth=0.5)
17+
axes[0].legend((l,), ('UR',), loc=0)
18+
19+
# Legend best location is "upper left"
20+
l, = axes[1].plot(t, np.cos(t) * np.exp(0.15 * t), linewidth=0.5)
21+
axes[1].legend((l,), ('UL',), loc=0)
22+
23+
# Legend best location is "lower left"
24+
l, = axes[2].plot(t, np.cos(5. * t) + 1, linewidth=0.5)
25+
axes[2].legend((l,), ('LL',), loc=0)
26+
27+
# Legend best location is "lower right"
28+
l, = axes[3].plot(t, 2 * np.cos(5. * t) * np.exp(-0.5 * t) + 0.2 * t,
29+
linewidth=0.5)
30+
axes[3].legend((l,), ('LR',), loc=0)
31+
32+
# Legend best location is "center left"
33+
l, = axes[4].plot(t[30:], 2 * np.cos(10 * t[30:]), linewidth=0.5)
34+
l2, l3 = axes[4].plot(t, -1.5 * np.ones_like(t), t, 1.5 * np.ones_like(t))
35+
axes[4].legend((l,), ('CL',), loc=0)
36+
37+
# Legend best location is "center right"
38+
l, = axes[5].plot(t[:30], 2 * np.cos(10 * t[:30]), linewidth=0.5)
39+
l2, l3 = axes[5].plot(t, -1.5 * np.ones_like(t), t, 1.5 * np.ones_like(t))
40+
axes[5].legend((l,), ('CR',), loc=0)
41+
42+
# Legend best location is "lower center"
43+
l, = axes[6].plot(t, -3 * np.cos(t) * np.exp(-0.1 * t), linewidth=0.5)
44+
axes[6].legend((l,), ('LC',), loc=0)
45+
46+
# Legend best location is "upper center"
47+
l, = axes[7].plot(t, 3 * np.cos(t) * np.exp(-0.1 * t), linewidth=0.5)
48+
axes[7].legend((l,), ('UC',), loc=0)
49+
50+
# Legend best location is "center"
51+
l, l1 = axes[8].plot(t[:10], 2 * np.cos(10 * t[:10]), t[-10:],
52+
2 * np.cos(10 * t[-10:]), linewidth=0.5)
53+
l2, l3 = axes[8].plot(t, -2 * np.ones_like(t), t, 2 * np.ones_like(t))
54+
axes[8].legend((l,), ('C',), loc=0)
55+
56+
return fig

test/testfunctions/legends2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# -*- coding: utf-8 -*-
22
#
33
desc = 'Multiple legend positions'
4-
# phash = 'd558f444f0542bbb'
5-
phash = '55d45cd47ad4812f'
4+
phash = '6b447a5a62d4952f'
65

76

87
def plot():

test/testfunctions/text_overlay.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# -*- coding: utf-8 -*-
22
#
33
desc = 'Regular plot with overlay text'
4-
# phash = '770b23744b93c68d'
5-
phash = '37092b3649d3f64c'
4+
phash = '370da93449d3f64c'
65

76

87
def plot():

0 commit comments

Comments
 (0)