Skip to content

Commit ba7276b

Browse files
committed
Reimplement resizing of plot compositions
1 parent 9594278 commit ba7276b

File tree

10 files changed

+145
-37
lines changed

10 files changed

+145
-37
lines changed

plotnine/_mpl/gridspec.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,20 @@ def bbox(self):
243243
"""
244244
return TransformedBbox(self.bbox_relative, self.figure.transSubfigure)
245245

246+
@property
247+
def width(self) -> float:
248+
"""
249+
Width of bbox in figure space
250+
"""
251+
return self.bbox_relative.width
252+
253+
@property
254+
def height(self) -> float:
255+
"""
256+
Height of bbox in figure space
257+
"""
258+
return self.bbox_relative.height
259+
246260
def to_transform(self) -> Transform:
247261
"""
248262
Return transform of this gridspec

plotnine/_mpl/layout_manager/_layout_tree.py

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -258,55 +258,53 @@ def right_most_spaces(self) -> list[right_spaces]:
258258
@property
259259
def panel_width(self) -> float:
260260
"""
261-
A representative width for panels of the nodes
261+
A width of all panels in this composition
262262
"""
263263
return sum(self.panel_widths)
264264

265265
@property
266266
def panel_height(self) -> float:
267267
"""
268-
A representative height for panels of the nodes
268+
A height of all panels in this composition
269269
"""
270270
return sum(self.panel_heights)
271271

272272
@property
273273
def plot_width(self) -> float:
274274
"""
275-
A representative for width for plots of the nodes
275+
A width of all plots in this tree/composition
276276
"""
277-
return sum(self.plot_widths)
277+
return self.gridspec.width
278278

279279
@property
280280
def plot_height(self) -> float:
281281
"""
282-
A representative for height for plots of the nodes
282+
A height of all plots in this tree/composition
283283
"""
284-
return sum(self.plot_heights)
284+
return self.gridspec.height
285285

286286
@property
287287
def panel_widths(self) -> Sequence[float]:
288288
"""
289289
Widths [figure space] of the panels along horizontal dimension
290-
291-
For each column, the effective panel width is the width of the
292-
shortest panel.
293290
"""
294-
n = self.ncol
291+
# This method is used after aligning the panels. Therefore, the
292+
# wides panel_width (i.e. max()) is the good representative width
293+
# of the column.
294+
w = self.plot_width / self.ncol
295295
return [
296-
max([node.panel_width if node else 1 / n for node in col])
296+
max(node.panel_width for node in col if node) if any(col) else w
297297
for col in self.grid.iter_cols()
298298
]
299299

300300
@property
301301
def panel_heights(self) -> Sequence[float]:
302302
"""
303303
Heights [figure space] of the panels along vertical dimension
304-
305-
For each row, the representative height is that of the shortest panel.
306304
"""
307-
n = self.nrow
305+
h = self.plot_height / self.nrow
308306
return [
309-
max([node.panel_height if node else 1 / n for node in row])
307+
max([node.panel_height for node in row if node]) if any(row) else h
310308
for row in self.grid.iter_rows()
311309
]
312310

@@ -315,11 +313,11 @@ def plot_widths(self) -> Sequence[float]:
315313
"""
316314
Widths [figure space] of the plots along horizontal dimension
317315
318-
For each column, the representative width is that of the shortest plot.
316+
For each column, the representative width is that of the widest plot.
319317
"""
320-
n = self.ncol
318+
w = self.gridspec.width / self.ncol
321319
return [
322-
max([node.plot_width if node else 1 / n for node in col])
320+
max([node.plot_width if node else w for node in col])
323321
for col in self.grid.iter_cols()
324322
]
325323

@@ -328,14 +326,32 @@ def plot_heights(self) -> Sequence[float]:
328326
"""
329327
Heights [figure space] of the plots along vertical dimension
330328
331-
For each row, the representative height is that of the shortest plot.
329+
For each row, the representative height is that of the tallest plot.
332330
"""
333-
n = self.nrow
331+
h = self.gridspec.height / self.nrow
334332
return [
335-
max([node.plot_height if node else 1 / n for node in row])
333+
max([node.plot_height if node else h for node in row])
336334
for row in self.grid.iter_rows()
337335
]
338336

337+
@property
338+
def panel_width_ratios(self) -> Sequence[float]:
339+
"""
340+
The relative widths of the panels in the composition
341+
342+
These are normalised to have a mean = 1.
343+
"""
344+
return cast("Sequence[float]", self.cmp._layout.widths)
345+
346+
@property
347+
def panel_height_ratios(self) -> Sequence[float]:
348+
"""
349+
The relative heights of the panels in the composition
350+
351+
These are normalised to have a mean = 1.
352+
"""
353+
return cast("Sequence[float]", self.cmp._layout.heights)
354+
339355
def bottom_spaces_in_row(self, r: int) -> list[bottom_spaces]:
340356
spaces: list[bottom_spaces] = []
341357
for node in self.grid[r, :]:
@@ -509,21 +525,22 @@ def align_axis_titles(self):
509525
tree.align_axis_titles()
510526

511527
def resize_widths(self):
512-
n = self.ncol
513-
resize_ratios = np.array(self.cmp._layout.widths)
514-
base_panel_widths = np.ones(n) * self.panel_width
515-
scaled_panel_widths = base_panel_widths * resize_ratios
528+
# The scaling calcuation to get the new panel width is
529+
# straight-forward because the ratios have a mean of 1.
530+
# So the multiplication preserves the total panel width.
531+
new_panel_widths = np.mean(self.panel_widths) * np.array(
532+
self.panel_width_ratios
533+
)
516534
non_panel_space = np.array(self.plot_widths) - self.panel_widths
517-
new_plot_widths = scaled_panel_widths + non_panel_space
535+
new_plot_widths = new_panel_widths + non_panel_space
518536
width_ratios = new_plot_widths / new_plot_widths.max()
519537
self.gridspec.set_width_ratios(width_ratios)
520538

521539
def resize_heights(self):
522-
n = self.nrow
523-
resize_ratios = np.array(self.cmp._layout.heights)
524-
base_panel_heights = np.ones(n) * self.panel_height
525-
scaled_panel_heights = base_panel_heights * resize_ratios
540+
new_panel_heights = np.mean(self.panel_heights) * np.array(
541+
self.panel_height_ratios
542+
)
526543
non_panel_space = np.array(self.plot_heights) - self.panel_heights
527-
new_plot_heights = scaled_panel_heights + non_panel_space
544+
new_plot_heights = new_panel_heights + non_panel_space
528545
height_ratios = new_plot_heights / new_plot_heights.max()
529546
self.gridspec.set_height_ratios(height_ratios)

plotnine/_mpl/layout_manager/_spaces.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -924,14 +924,14 @@ def plot_width(self) -> float:
924924
"""
925925
Width [figure dimensions] of the whole plot
926926
"""
927-
return self.plot._gridspec.bbox_relative.width
927+
return float(self.plot._gridspec.width)
928928

929929
@property
930930
def plot_height(self) -> float:
931931
"""
932932
Height [figure dimensions] of the whole plot
933933
"""
934-
return self.plot._gridspec.bbox_relative.height
934+
return float(self.plot._gridspec.height)
935935

936936
@property
937937
def panel_width(self) -> float:

plotnine/composition/_plot_layout.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ def repeat(seq: Sequence[float], n: int) -> list[float]:
8686
return [val for _, val in zip(range(n), cycle(seq))]
8787

8888

89-
def normalise(seq) -> list[float]:
90-
total = sum(seq)
91-
return [x / total for x in seq]
89+
def normalise(seq: Sequence[float]) -> list[float]:
90+
"""
91+
Normalise seq so that the mean is 1
92+
"""
93+
mean = sum(seq) / len(seq)
94+
if mean == 0:
95+
raise ValueError("Cannot rescale: mean is zero")
96+
return [x / mean for x in seq]
13.7 KB
Loading
7.55 KB
Loading
9.05 KB
Loading
7.64 KB
Loading
19.8 KB
Loading

tests/test_plot_composition.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
from plotnine import element_text, facet_grid, facet_wrap, theme, theme_gray
1+
from plotnine import (
2+
element_text,
3+
facet_grid,
4+
facet_wrap,
5+
labs,
6+
theme,
7+
theme_gray,
8+
)
29
from plotnine._utils.yippie import geom as g
310
from plotnine._utils.yippie import legend, plot, rotate, tag
411
from plotnine.composition._plot_layout import plot_layout
@@ -179,3 +186,68 @@ def test_plot_layout_nested_resize():
179186

180187
p = (((p1 | p2) + ws) / ((p3 | p4) + ws)) + hs
181188
assert p == "plot_layout_nested_resize"
189+
190+
191+
def test_plot_layout_extra_cols():
192+
p1 = plot.red
193+
p2 = plot.green
194+
p3 = plot.blue
195+
p = (p1 | p2 | p3) + plot_layout(ncol=5)
196+
assert p == "plot_layout_extra_cols"
197+
198+
199+
def test_plot_layout_extra_col_width():
200+
# An extra column is extactly panel_width wide (no margin)
201+
# By stacking two rows, where one has an extra column, we can
202+
# confirm the size.
203+
p1 = plot.red
204+
p2 = plot.green
205+
p3 = plot.blue
206+
p4 = plot.yellow + labs(y="") + theme(plot_margin=0)
207+
208+
c1 = (p1 | p2 | p3) + plot_layout(ncol=4)
209+
c2 = p1 | p2 | p3 | p4
210+
p = c1 / c2
211+
assert p == "plot_layout_extra_col_width"
212+
213+
214+
def test_plot_layout_extra_rows():
215+
p1 = plot.red
216+
p2 = plot.green
217+
p3 = plot.blue
218+
p = (p1 / p2 / p3) + plot_layout(nrow=5)
219+
assert p == "plot_layout_extra_rows"
220+
221+
222+
def test_plot_layout_extra_row_width():
223+
# An extra row is extactly panel_width wide (no margin)
224+
# By stacking two rows, where one has an extra row, we can
225+
# confirm the size.
226+
p1 = plot.red
227+
p2 = plot.green
228+
p3 = plot.blue
229+
p4 = plot.yellow + labs(x="", title="") + theme(plot_margin=0)
230+
231+
c1 = (p1 / p2 / p3) + plot_layout(nrow=4)
232+
c2 = p1 / p2 / p3 / p4
233+
p = c1 | c2
234+
assert p == "plot_layout_extra_row_width"
235+
236+
237+
def test_wrap_complicated():
238+
p1 = plot.red
239+
p2 = plot.green
240+
p3 = plot.blue + g.points
241+
p4 = plot.yellow
242+
p5 = plot.cyan
243+
p6 = plot.orange
244+
p7 = plot.purple + g.cols
245+
246+
c1 = (p1 + p2 + p3) + plot_layout(ncol=2)
247+
c2 = (p4 + p5 + p6 + p7) + plot_layout(ncol=4, widths=[1, 2, 4, 8])
248+
249+
# The top composition has two rows and the bottom one has one.
250+
# With [1, 2] height ratios, the panels in each row should have
251+
# the same height.
252+
p = (c1 / c2) + plot_layout(heights=[2, 1])
253+
assert p == "wrap_complicated"

0 commit comments

Comments
 (0)