Skip to content

Commit ec3bcb6

Browse files
Added tests for legend_loc
1 parent fd992a0 commit ec3bcb6

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

pandas/tests/plotting/frame/test_frame_legend.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
_check_legend_labels,
1212
_check_legend_marker,
1313
_check_text_labels,
14+
_check_plot_works,
1415
)
1516

1617
mpl = pytest.importorskip("matplotlib")
@@ -260,3 +261,55 @@ def test_missing_markers_legend_using_style(self):
260261

261262
_check_legend_labels(ax, labels=["A", "B", "C"])
262263
_check_legend_marker(ax, expected_markers=[".", ".", "."])
264+
265+
def test_no_legend_when_legend_false_legend_loc(self):
266+
df = DataFrame({"a": [1, 2, 3]})
267+
ax = _check_plot_works(df.plot, legend=False, legend_loc="upper right")
268+
assert ax.get_legend() is None
269+
270+
def test_df_legend_labels_secondary_y_legend_loc(self):
271+
pytest.importorskip("scipy")
272+
df = DataFrame(np.random.default_rng(2).random((3, 3)), columns=["a", "b", "c"])
273+
df2 = DataFrame(
274+
np.random.default_rng(2).random((3, 3)), columns=["d", "e", "f"]
275+
)
276+
df3 = DataFrame(
277+
np.random.default_rng(2).random((3, 3)), columns=["g", "h", "i"]
278+
)
279+
280+
# preserves "best" behaviour if no legend_loc specified
281+
ax = df.plot(legend=True, secondary_y="b", legend_loc=None)
282+
_check_legend_labels(ax, labels=["a", "b (right)", "c"])
283+
ax = df2.plot(legend=False, ax=ax)
284+
_check_legend_labels(ax, labels=["a", "b (right)", "c"])
285+
ax = df3.plot(kind="bar", legend=True, secondary_y="h", legend_loc="upper right", ax=ax)
286+
_check_legend_labels(ax, labels=["a", "b (right)", "c", "g", "h (right)", "i"])
287+
288+
@pytest.mark.parametrize(
289+
"kind, labels",
290+
[
291+
("line", ["a", "b"]),
292+
("bar", ["a", "b"]),
293+
("scatter", ["b"]), # scatter(x="a", y="b") usually labels the y series
294+
])
295+
def test_across_kinds_legend_loc(self, kind, labels):
296+
df = DataFrame({"a": [1, 2, 3], "b": [3, 2, 1]})
297+
if kind == "scatter":
298+
ax = _check_plot_works(
299+
df.plot,
300+
kind="scatter",
301+
x="a",
302+
y="b",
303+
label="b",
304+
legend=True,
305+
legend_loc="upper right"
306+
)
307+
else:
308+
ax = _check_plot_works(
309+
df.plot,
310+
kind=kind,
311+
legend=True,
312+
legend_loc="upper right"
313+
)
314+
_check_legend_labels(ax, labels=labels)
315+
assert ax.get_legend() is not None

0 commit comments

Comments
 (0)