Skip to content

Commit b9333f9

Browse files
authored
Add gap/fill parameters to barplot and countplot (#3361)
* Add fill and gap parameters to barplot * Add params to docstring * Add gap/fill to catplot * Add gap/fill to countplot * Make default unfilled bar linewidth thicker * Add unit tests * Allow err_kws to override unfilled error bar color
1 parent 16c025d commit b9333f9

File tree

2 files changed

+64
-15
lines changed

2 files changed

+64
-15
lines changed

seaborn/categorical.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,9 @@ def plot_bars(
10541054
self,
10551055
aggregator,
10561056
dodge,
1057+
gap,
10571058
width,
1059+
fill,
10581060
color,
10591061
capsize,
10601062
err_kws,
@@ -1072,7 +1074,9 @@ def plot_bars(
10721074
if dodge and capsize is not None:
10731075
capsize = capsize / len(self._hue_map.levels)
10741076

1075-
err_kws.setdefault("color", ".26")
1077+
if not fill:
1078+
plot_kws.setdefault("linewidth", 1.5 * mpl.rcParams["lines.linewidth"])
1079+
10761080
err_kws.setdefault("linewidth", 1.5 * mpl.rcParams["lines.linewidth"])
10771081

10781082
for sub_vars, sub_data in self.iter_data(iter_vars,
@@ -1091,6 +1095,8 @@ def plot_bars(
10911095
agg_data["width"] = width * self._native_width
10921096
if dodge:
10931097
self._dodge(sub_vars, agg_data)
1098+
if gap:
1099+
agg_data["width"] *= 1 - gap
10941100

10951101
agg_data["edge"] = agg_data[self.orient] - agg_data["width"] / 2
10961102
self._invert_scale(ax, agg_data)
@@ -1106,17 +1112,22 @@ def plot_bars(
11061112
y=agg_data["edge"], width=agg_data["x"], height=agg_data["width"]
11071113
)
11081114

1109-
maincolor = self._hue_map(sub_vars["hue"]) if "hue" in sub_vars else color
1115+
main_color = self._hue_map(sub_vars["hue"]) if "hue" in sub_vars else color
11101116

11111117
# Set both color and facecolor for property cycle logic
1112-
kws["color"] = maincolor
1113-
kws["facecolor"] = maincolor
11141118
kws["align"] = "edge"
1119+
if fill:
1120+
kws.update(color=main_color, facecolor=main_color)
1121+
else:
1122+
kws.update(color=main_color, edgecolor=main_color, facecolor="none")
11151123

11161124
bar_func(**{**kws, **plot_kws})
11171125

11181126
if aggregator.error_method is not None:
1119-
self.plot_errorbars(ax, agg_data, capsize, err_kws.copy())
1127+
self.plot_errorbars(
1128+
ax, agg_data, capsize,
1129+
{"color": ".26" if fill else main_color, **err_kws}
1130+
)
11201131

11211132
self._configure_legend(ax, ax.fill_between)
11221133

@@ -2728,9 +2739,9 @@ def swarmplot(
27282739
def barplot(
27292740
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
27302741
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
2731-
orient=None, color=None, palette=None, saturation=.75, hue_norm=None, width=.8,
2732-
dodge="auto", native_scale=False, formatter=None, legend="auto", capsize=0,
2733-
err_kws=None, ci=deprecated, errcolor=deprecated, errwidth=deprecated,
2742+
orient=None, color=None, palette=None, saturation=.75, fill=True, hue_norm=None,
2743+
width=.8, dodge="auto", gap=0, native_scale=False, formatter=None, legend="auto",
2744+
capsize=0, err_kws=None, ci=deprecated, errcolor=deprecated, errwidth=deprecated,
27342745
ax=None, **kwargs,
27352746
):
27362747

@@ -2769,6 +2780,7 @@ def barplot(
27692780
hue_order = p._palette_without_hue_backcompat(palette, hue_order)
27702781
palette, hue_order = p._hue_backcompat(color, palette, hue_order)
27712782

2783+
saturation = saturation if fill else 1
27722784
p.map_hue(palette=palette, order=hue_order, norm=hue_norm, saturation=saturation)
27732785
color = _default_color(ax.bar, hue, color, kwargs, saturation=saturation)
27742786

@@ -2782,7 +2794,9 @@ def barplot(
27822794
aggregator=aggregator,
27832795
dodge=dodge,
27842796
width=width,
2797+
gap=gap,
27852798
color=color,
2799+
fill=fill,
27862800
capsize=capsize,
27872801
err_kws=err_kws,
27882802
plot_kws=kwargs,
@@ -2815,13 +2829,15 @@ def barplot(
28152829
{color}
28162830
{palette}
28172831
{saturation}
2832+
{fill}
28182833
{hue_norm}
28192834
{width}
2820-
{capsize}
28212835
{dodge}
2836+
{gap}
28222837
{native_scale}
28232838
{formatter}
28242839
{legend}
2840+
{capsize}
28252841
{err_kws}
28262842
{ci}
28272843
{errcolor}
@@ -3007,8 +3023,8 @@ def pointplot(
30073023

30083024
def countplot(
30093025
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
3010-
orient=None, color=None, palette=None, saturation=.75, hue_norm=None,
3011-
stat="count", width=.8, dodge="auto", native_scale=False, formatter=None,
3026+
orient=None, color=None, palette=None, saturation=.75, fill=True, hue_norm=None,
3027+
stat="count", width=.8, dodge="auto", gap=0, native_scale=False, formatter=None,
30123028
legend="auto", ax=None, **kwargs
30133029
):
30143030

@@ -3049,6 +3065,7 @@ def countplot(
30493065
hue_order = p._palette_without_hue_backcompat(palette, hue_order)
30503066
palette, hue_order = p._hue_backcompat(color, palette, hue_order)
30513067

3068+
saturation = saturation if fill else 1
30523069
p.map_hue(palette=palette, order=hue_order, norm=hue_norm, saturation=saturation)
30533070
color = _default_color(ax.bar, hue, color, kwargs, saturation)
30543071

@@ -3068,7 +3085,9 @@ def countplot(
30683085
aggregator=aggregator,
30693086
dodge=dodge,
30703087
width=width,
3088+
gap=gap,
30713089
color=color,
3090+
fill=fill,
30723091
capsize=0,
30733092
err_kws={},
30743093
plot_kws=kwargs,
@@ -3409,19 +3428,22 @@ def catplot(
34093428
aggregator = EstimateAggregator(
34103429
estimator, errorbar, n_boot=n_boot, seed=seed
34113430
)
3412-
34133431
err_kws, capsize = p._err_kws_backcompat(
34143432
_normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D),
34153433
errcolor=kwargs.pop("errcolor", deprecated),
34163434
errwidth=kwargs.pop("errwidth", deprecated),
34173435
capsize=kwargs.pop("capsize", 0),
34183436
)
3437+
gap = kwargs.pop("gap", 0)
3438+
fill = kwargs.pop("fill", True)
34193439

34203440
p.plot_bars(
34213441
aggregator=aggregator,
34223442
dodge=dodge,
34233443
width=width,
3444+
gap=gap,
34243445
color=color,
3446+
fill=fill,
34253447
capsize=capsize,
34263448
err_kws=err_kws,
34273449
plot_kws=kwargs,
@@ -3441,11 +3463,16 @@ def catplot(
34413463
denom = 100 if stat == "percent" else 1
34423464
p.plot_data[count_axis] /= len(p.plot_data) / denom
34433465

3466+
gap = kwargs.pop("gap", 0)
3467+
fill = kwargs.pop("fill", True)
3468+
34443469
p.plot_bars(
34453470
aggregator=aggregator,
34463471
dodge=dodge,
34473472
width=width,
3473+
gap=gap,
34483474
color=color,
3475+
fill=fill,
34493476
capsize=0,
34503477
err_kws={},
34513478
plot_kws=kwargs,

tests/test_categorical.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,6 +2017,16 @@ def test_hue_dodged(self):
20172017
assert bar.get_width() == approx(0.8 / 2)
20182018
assert same_color(bar.get_facecolor(), f"C{i // 2}")
20192019

2020+
def test_gap(self):
2021+
2022+
x = ["a", "b", "a", "b"]
2023+
y = [1, 2, 3, 4]
2024+
hue = ["x", "x", "y", "y"]
2025+
2026+
ax = barplot(x=x, y=y, hue=hue, gap=.25)
2027+
for i, bar in enumerate(ax.patches):
2028+
assert bar.get_width() == approx(0.8 / 2 * .75)
2029+
20202030
def test_hue_undodged(self):
20212031

20222032
x = ["a", "b", "a", "b"]
@@ -2051,6 +2061,17 @@ def test_hue_norm(self):
20512061
assert colors[1] != colors[2]
20522062
assert colors[2] == colors[3]
20532063

2064+
def test_fill(self):
2065+
2066+
x = ["a", "b", "a", "b"]
2067+
y = [1, 2, 3, 4]
2068+
hue = ["x", "x", "y", "y"]
2069+
2070+
ax = barplot(x=x, y=y, hue=hue, fill=False)
2071+
for i, bar in enumerate(ax.patches):
2072+
assert same_color(bar.get_edgecolor(), f"C{i // 2}")
2073+
assert same_color(bar.get_facecolor(), (0, 0, 0, 0))
2074+
20542075
def test_xy_native_scale(self):
20552076

20562077
x, y = [2, 4, 8], [1, 2, 3]
@@ -2261,11 +2282,12 @@ def test_bar_kwargs(self):
22612282
assert bar.get_facecolor() == kwargs["facecolor"]
22622283
assert bar.get_rasterized() == kwargs["rasterized"]
22632284

2264-
def test_err_kws(self):
2285+
@pytest.mark.parametrize("fill", [True, False])
2286+
def test_err_kws(self, fill):
22652287

22662288
x, y = ["a", "b", "c"], [1, 2, 3]
22672289
err_kws = dict(color=(1, 1, .5, .5), linewidth=5)
2268-
ax = barplot(x=x, y=y, err_kws=err_kws)
2290+
ax = barplot(x=x, y=y, fill=fill, err_kws=err_kws)
22692291
for line in ax.lines:
22702292
assert line.get_color() == err_kws["color"]
22712293
assert line.get_linewidth() == err_kws["linewidth"]
@@ -2284,7 +2306,7 @@ def test_err_kws(self):
22842306
dict(data=None, x="s", y="y", hue="a"),
22852307
dict(data="long", x="a", y="y", hue="s"),
22862308
dict(data="long", x="a", y="y", units="c"),
2287-
dict(data="null", x="a", y="y", hue="a"),
2309+
dict(data="null", x="a", y="y", hue="a", gap=.1, fill=False),
22882310
dict(data="long", x="s", y="y", hue="a", native_scale=True),
22892311
dict(data="long", x="d", y="y", hue="a", native_scale=True),
22902312
dict(data="long", x="a", y="y", errorbar=("pi", 50)),

0 commit comments

Comments
 (0)