Skip to content

Commit 5bac1e1

Browse files
committed
Do not overwrite labels set with labs
... unless if using labs to overwrite.
1 parent 097192d commit 5bac1e1

File tree

6 files changed

+65
-40
lines changed

6 files changed

+65
-40
lines changed

doc/changelog.qmd

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
11
---
22
title: Changelog
33
---
4+
5+
## v0.15.1
6+
(not-yet-released)
7+
8+
### Bug Fixes
9+
10+
- Fixed labels set with the `labs` call so that they are only ever overwritten
11+
by a call to `labs` or setting the `name` of a scale. Previously, if a
12+
global mapping was added after `labs`, it over-wrote the previously set labels.
13+
14+
```python
15+
ggplot(mtcars) + labs(x="x title", y="y title") + geom_point() + aes("wt", "mpg")
16+
```
17+
18+
The labels will be "x title" and "y title".
19+
20+
421
## v0.15.0
522

623
(2025-06-15)

plotnine/ggplot.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
from .facets.layout import Layout
3737
from .geoms.geom_blank import geom_blank
3838
from .guides.guides import guides
39-
from .iapi import mpl_save_view
39+
from .iapi import labels_view, mpl_save_view
4040
from .layer import Layers
41-
from .mapping.aes import aes, make_labels
41+
from .mapping.aes import aes
4242
from .options import get_option
4343
from .scales.scales import Scales
4444
from .themes.theme import theme, theme_get
@@ -55,7 +55,6 @@
5555
from plotnine.composition import Compose
5656
from plotnine.coords.coord import coord
5757
from plotnine.facets.facet import facet
58-
from plotnine.layer import layer
5958
from plotnine.typing import DataLike
6059

6160
class PlotAddable(Protocol):
@@ -118,7 +117,7 @@ def __init__(
118117
self.data = data
119118
self.mapping = mapping if mapping is not None else aes()
120119
self.facet: facet = facet_null()
121-
self.labels = make_labels(self.mapping)
120+
self.labels = labels_view()
122121
self.layers = Layers()
123122
self.guides = guides()
124123
self.scales = Scales()
@@ -301,10 +300,7 @@ def draw(self, *, show: bool = False) -> Figure:
301300
from ._mpl.layout_manager import PlotnineLayoutEngine
302301

303302
with plot_context(self, show=show):
304-
if not hasattr(self, "figure"):
305-
self._create_figure()
306-
figure = self.figure
307-
303+
figure = self._setup()
308304
self._build()
309305

310306
# setup
@@ -327,6 +323,16 @@ def draw(self, *, show: bool = False) -> Figure:
327323

328324
return figure
329325

326+
def _setup(self) -> Figure:
327+
"""
328+
Setup this instance for the building process
329+
"""
330+
if not hasattr(self, "figure"):
331+
self._create_figure()
332+
333+
self.labels.add_defaults(self.mapping.labels)
334+
return self.figure
335+
330336
def _create_figure(self):
331337
"""
332338
Create gridspec for the panels
@@ -548,21 +554,6 @@ def _save_filename(self, ext: str) -> Path:
548554
hash_token = abs(self.__hash__())
549555
return Path(f"plotnine-save-{hash_token}.{ext}")
550556

551-
def _update_labels(self, layer: layer):
552-
"""
553-
Update label data for the ggplot
554-
555-
Parameters
556-
----------
557-
layer : layer
558-
New layer that has just been added to the ggplot
559-
object.
560-
"""
561-
mapping = make_labels(layer.mapping)
562-
default = make_labels(layer.stat.DEFAULT_AES)
563-
mapping.add_defaults(default)
564-
self.labels.add_defaults(mapping)
565-
566557
def save_helper(
567558
self: ggplot,
568559
filename: Optional[str | Path | BytesIO] = None,

plotnine/layer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ._utils import array_kind, check_required_aesthetics, ninteraction
1010
from .exceptions import PlotnineError
11-
from .mapping.aes import NO_GROUP, SCALED_AESTHETICS, aes
11+
from .mapping.aes import NO_GROUP, SCALED_AESTHETICS, aes, make_labels
1212
from .mapping.evaluation import evaluate, stage
1313

1414
if typing.TYPE_CHECKING:
@@ -409,6 +409,13 @@ def finish_statistics(self):
409409
"""
410410
self.stat.finish_layer(self.data)
411411

412+
def update_labels(self, plot: ggplot):
413+
"""
414+
Update label data for the ggplot from the mappings in this layer
415+
"""
416+
plot.labels.add_defaults(self.mapping.labels)
417+
plot.labels.add_defaults(make_labels(self.stat.DEFAULT_AES))
418+
412419

413420
class Layers(List[layer]):
414421
"""
@@ -509,7 +516,7 @@ def finish_statistics(self):
509516

510517
def update_labels(self, plot: ggplot):
511518
for l in self:
512-
plot._update_labels(l)
519+
l.update_labels(plot)
513520

514521

515522
def add_group(data: pd.DataFrame) -> pd.DataFrame:

plotnine/mapping/aes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,15 @@ def __radd__(self, other):
290290
"""
291291
self = deepcopy(self)
292292
other.mapping.update(self)
293-
other.labels.update(make_labels(self))
294293
return other
295294

295+
@property
296+
def labels(self) -> labels_view:
297+
"""
298+
The labels for this mapping
299+
"""
300+
return make_labels(self)
301+
296302
def copy(self):
297303
return aes(**self)
298304

tests/test_aes.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
geom_crossbar,
1010
geom_point,
1111
ggplot,
12+
labs,
1213
scale_x_log10,
1314
scale_y_log10,
1415
stage,
@@ -26,14 +27,6 @@
2627
)
2728

2829

29-
data = pd.DataFrame(
30-
{
31-
"x": pd.Categorical(["b", "d", "c", "a"], ordered=True),
32-
"y": [1, 2, 3, 4],
33-
}
34-
)
35-
36-
3730
def test_reorder():
3831
p = (
3932
ggplot(data, aes("reorder(x, y)", "y", fill="reorder(x, y)"))
@@ -50,15 +43,17 @@ def test_reorder_index():
5043

5144

5245
def test_labels_series():
53-
p = ggplot(data, aes(x=data.x, y=data.y)) + geom_col()
46+
p = ggplot(data, aes(x=data.x, y=data.y)) + geom_col() + labs(y="yy")
47+
p.draw()
5448
assert p.labels.x == "x"
55-
assert p.labels.y == "y"
49+
assert p.labels.y == "yy"
5650

5751

5852
def test_labels_lists():
59-
p = ggplot(data, aes(x=[1, 2, 3], y=[1, 2, 3])) + geom_col()
60-
assert p.labels.x is None
61-
assert p.labels.y is None
53+
p = ggplot(data, aes(x=[0, 1, 2, 3], y=range(4))) + geom_col()
54+
p.draw()
55+
assert p.labels.x == ""
56+
assert p.labels.y == ""
6257

6358

6459
def test_irregular_shapes():

tests/test_ggplot_internals.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,21 @@ def test_add_aes():
222222
data = pd.DataFrame({"var1": [1, 2, 3, 4], "var2": 2})
223223
p = ggplot(data) + geom_point()
224224
p += aes("var1", "var2")
225-
225+
p.draw()
226226
assert p.mapping == aes("var1", "var2")
227227
assert p.labels.x == "var1"
228228
assert p.labels.y == "var2"
229229

230230

231+
def test_add_labs():
232+
data = pd.DataFrame({"var1": [1, 2, 3, 4], "var2": 2})
233+
p = ggplot(data) + geom_point() + labs(x="x title")
234+
p += aes("var1", "var2")
235+
p.draw()
236+
assert p.labels.x == "x title"
237+
assert p.labels.y == "var2"
238+
239+
231240
def test_nonzero_indexed_data():
232241
data = pd.DataFrame(
233242
{98: {"blip": 0, "blop": 1}, 99: {"blip": 1, "blop": 3}}

0 commit comments

Comments
 (0)