Skip to content

Commit 15ff0b4

Browse files
implemented recursive formulation for cleanfigure
1 parent 72cde2a commit 15ff0b4

File tree

2 files changed

+86
-52
lines changed

2 files changed

+86
-52
lines changed

test/test_cleanfigure.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@
77
RC_PARAMS = {"figure.figsize": [5, 5], "figure.dpi": 220, "pgf.rcfonts": False}
88

99

10+
def test_recursive_cleanfigure():
11+
x = np.linspace(1, 100, 20)
12+
y = np.linspace(1, 100, 20)
13+
14+
with plt.rc_context(rc=RC_PARAMS):
15+
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
16+
(l,) = ax.plot(x, y)
17+
ax.set_ylim([20, 80])
18+
ax.set_xlim([20, 80])
19+
cleanfigure._recursive_cleanfigure(fig)
20+
21+
1022
def test_pruneOutsideBox():
1123
"""test against matlab2tikz implementation
1224
@@ -389,7 +401,7 @@ def test_plot(self):
389401
ax.set_xlim([20, 80])
390402
raw = get_tikz_code()
391403

392-
cleanfigure.cleanfigure(fig, ax)
404+
cleanfigure.cleanfigure(fig)
393405
clean = get_tikz_code()
394406

395407
# Use number of lines to test if it worked.
@@ -411,7 +423,7 @@ def test_step(self):
411423
ax.set_ylim([20, 80])
412424
ax.set_xlim([20, 80])
413425
with pytest.warns(Warning):
414-
cleanfigure.cleanfigure(fig, ax)
426+
cleanfigure.cleanfigure(fig)
415427

416428
def test_scatter(self):
417429
# TODO: scatter plots are represented through axes.collections. Currently, this is simply ignored and nothing is done.
@@ -424,7 +436,7 @@ def test_scatter(self):
424436
ax.set_ylim([20, 80])
425437
ax.set_xlim([20, 80])
426438
with pytest.warns(Warning):
427-
cleanfigure.cleanfigure(fig, ax)
439+
cleanfigure.cleanfigure(fig)
428440

429441
def test_bar(self):
430442

@@ -436,7 +448,7 @@ def test_bar(self):
436448
ax.set_ylim([20, 80])
437449
ax.set_xlim([20, 80])
438450
with pytest.warns(Warning):
439-
cleanfigure.cleanfigure(fig, ax)
451+
cleanfigure.cleanfigure(fig)
440452

441453
def test_hist(self):
442454
"""creates same test case as bar"""
@@ -449,7 +461,7 @@ def test_hist(self):
449461
ax.set_ylim([20, 80])
450462
ax.set_xlim([20, 80])
451463
with pytest.warns(Warning):
452-
cleanfigure.cleanfigure(fig, ax)
464+
cleanfigure.cleanfigure(fig)
453465

454466
def test_plot3d(self):
455467
from mpl_toolkits.mplot3d import Axes3D
@@ -469,7 +481,7 @@ def test_plot3d(self):
469481
ax.view_init(30, 30)
470482
raw = get_tikz_code(fig)
471483

472-
cleanfigure.cleanfigure(fig, ax)
484+
cleanfigure.cleanfigure(fig)
473485
clean = get_tikz_code()
474486

475487
# Use number of lines to test if it worked.
@@ -491,7 +503,7 @@ def test_scatter3d(self):
491503
ax.set_ylim([20, 80])
492504
ax.set_zlim([0, 80])
493505
with pytest.warns(Warning):
494-
cleanfigure.cleanfigure(fig, ax)
506+
cleanfigure.cleanfigure(fig)
495507

496508
def test_wireframe3D(self):
497509
from mpl_toolkits.mplot3d import axes3d
@@ -505,7 +517,7 @@ def test_wireframe3D(self):
505517
# Plot a basic wireframe.
506518
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10)
507519
with pytest.warns(Warning):
508-
cleanfigure.cleanfigure(fig, ax)
520+
cleanfigure.cleanfigure(fig)
509521

510522
def test_surface3D(self):
511523
from mpl_toolkits.mplot3d import Axes3D
@@ -535,7 +547,7 @@ def test_surface3D(self):
535547
fig.colorbar(surf, shrink=0.5, aspect=5)
536548

537549
with pytest.warns(Warning):
538-
cleanfigure.cleanfigure(fig, ax)
550+
cleanfigure.cleanfigure(fig)
539551

540552
def test_trisurface3D(Self):
541553
from mpl_toolkits.mplot3d import Axes3D
@@ -567,7 +579,7 @@ def test_trisurface3D(Self):
567579

568580
ax.plot_trisurf(x, y, z, linewidth=0.2, antialiased=True)
569581
with pytest.warns(Warning):
570-
cleanfigure.cleanfigure(fig, ax)
582+
cleanfigure.cleanfigure(fig)
571583

572584
def test_contour3D(self):
573585
from mpl_toolkits.mplot3d import axes3d
@@ -581,7 +593,7 @@ def test_contour3D(self):
581593
cset = ax.contour(X, Y, Z, cmap=cm.coolwarm)
582594
ax.clabel(cset, fontsize=9, inline=1)
583595
with pytest.warns(Warning):
584-
cleanfigure.cleanfigure(fig, ax)
596+
cleanfigure.cleanfigure(fig)
585597

586598
def test_polygon3D(self):
587599
from mpl_toolkits.mplot3d import Axes3D
@@ -615,7 +627,7 @@ def cc(arg):
615627
ax.set_zlabel('Z')
616628
ax.set_zlim3d(0, 1)
617629
with pytest.warns(Warning):
618-
cleanfigure.cleanfigure(fig, ax)
630+
cleanfigure.cleanfigure(fig)
619631

620632
def test_bar3D(self):
621633
from mpl_toolkits.mplot3d import Axes3D
@@ -640,7 +652,7 @@ def test_bar3D(self):
640652
ax.set_ylabel('Y')
641653
ax.set_zlabel('Z')
642654
with pytest.warns(Warning):
643-
cleanfigure.cleanfigure(fig, ax)
655+
cleanfigure.cleanfigure(fig)
644656

645657
def test_quiver3D(self):
646658
from mpl_toolkits.mplot3d import axes3d
@@ -665,7 +677,7 @@ def test_quiver3D(self):
665677

666678
ax.quiver(x, y, z, u, v, w, length=0.1, normalize=True)
667679
with pytest.warns(Warning):
668-
cleanfigure.cleanfigure(fig, ax)
680+
cleanfigure.cleanfigure(fig)
669681

670682
def test_2D_in_3D(self):
671683
from mpl_toolkits.mplot3d import Axes3D
@@ -705,7 +717,7 @@ def test_2D_in_3D(self):
705717
# on the plane y=0
706718
ax.view_init(elev=20., azim=-35)
707719
with pytest.warns(Warning):
708-
cleanfigure.cleanfigure(fig, ax)
720+
cleanfigure.cleanfigure(fig)
709721

710722

711723
class Test_lineplot:
@@ -724,7 +736,7 @@ def test_line_no_markers(self):
724736
ax.set_xlim([20, 80])
725737
raw = get_tikz_code()
726738

727-
cleanfigure.cleanfigure(fig, ax)
739+
cleanfigure.cleanfigure(fig)
728740
clean = get_tikz_code()
729741

730742
# Use number of lines to test if it worked.
@@ -750,7 +762,7 @@ def test_no_line_markers(self):
750762
ax.set_xlim([20, 80])
751763
raw = get_tikz_code()
752764

753-
cleanfigure.cleanfigure(fig, ax)
765+
cleanfigure.cleanfigure(fig)
754766
clean = get_tikz_code()
755767

756768
# Use number of lines to test if it worked.
@@ -776,7 +788,7 @@ def test_line_markers(self):
776788
ax.set_xlim([20, 80])
777789
raw = get_tikz_code()
778790

779-
cleanfigure.cleanfigure(fig, ax)
791+
cleanfigure.cleanfigure(fig)
780792
clean = get_tikz_code()
781793

782794
# Use number of lines to test if it worked.
@@ -799,7 +811,7 @@ def test_sine(self):
799811
ax.set_ylim([-1, 1])
800812
raw = get_tikz_code()
801813

802-
cleanfigure.cleanfigure(fig, ax)
814+
cleanfigure.cleanfigure(fig)
803815
clean = get_tikz_code()
804816

805817
# Use number of lines to test if it worked.

tikzplotlib/cleanfigure.py

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
import matplotlib
2+
import matplotlib as mpl
33
from matplotlib import pyplot as plt
44
import mpl_toolkits
55
from mpl_toolkits.mplot3d import Axes3D
@@ -14,7 +14,7 @@
1414
STEP_DRAW_STYLES = ["steps-pre", "steps-post", "steps-mid"]
1515

1616

17-
def cleanfigure(fighandle=None, axhandle=None, target_resolution=600, scalePrecision=1.0):
17+
def cleanfigure(fig=None, targetResolution=600, scalePrecision=1.0):
1818
"""Cleans figure as a preparation for tikz export.
1919
This will minimize the number of points required for the tikz figure.
2020
If the figure has subplots, it will recursively clean then up.
@@ -66,40 +66,62 @@ def cleanfigure(fighandle=None, axhandle=None, target_resolution=600, scalePreci
6666
6767
```
6868
"""
69-
if fighandle is None and axhandle is None:
70-
fighandle = plt.gcf()
71-
# recurse into subplots
72-
for axhandle in fighandle.axes:
73-
cleanfigure(fighandle, axhandle, target_resolution, scalePrecision)
74-
elif fighandle is None and (axhandle is not None):
75-
fighandle = axhandle.get_figure()
76-
elif (fighandle is not None) and (axhandle is None):
77-
# recurse into subplots
78-
for axhandle in fighandle.axes:
79-
cleanfigure(fighandle, axhandle, target_resolution, scalePrecision)
80-
81-
82-
# clean up ax.plot and ax.step
83-
for linehandle in axhandle.lines:
84-
if type(linehandle) in [matplotlib.lines.Line2D, mpl_toolkits.mplot3d.art3d.Line3D]:
85-
_cleanline(fighandle, axhandle, linehandle, target_resolution, scalePrecision)
69+
if fig is None:
70+
fig = plt.gcf()
71+
elif fig == "gcf": # tikzplotlib syntax
72+
fig = plt.gcf()
73+
_recursive_cleanfigure(fig, targetResolution=targetResolution, scalePrecision=scalePrecision)
74+
75+
76+
77+
def _recursive_cleanfigure(obj, targetResolution=600, scalePrecision=1.0):
78+
for child in obj.get_children():
79+
if isinstance(child, mpl.spines.Spine):
80+
pass
81+
if isinstance(child, mpl.axes.Axes):
82+
# Note: containers contain Patches but are not child objects.
83+
# This is a problem because a bar plot creates a Barcontainer.
84+
_clean_containers(child)
85+
_recursive_cleanfigure(child, targetResolution=targetResolution, scalePrecision=scalePrecision)
86+
elif isinstance(child, mpl_toolkits.mplot3d.axes3d.Axes3D):
87+
_clean_containers(child)
88+
_recursive_cleanfigure(child, targetResolution=targetResolution, scalePrecision=scalePrecision)
89+
elif isinstance(child, mpl.lines.Line2D):
90+
ax = child.axes
91+
fig = ax.figure
92+
_cleanline(fig, ax, linehandle=child, targetResolution=targetResolution, scalePrecision=scalePrecision)
93+
elif isinstance(child, mpl_toolkits.mplot3d.art3d.Line3D):
94+
ax = child.axes
95+
fig = ax.figure
96+
_cleanline(fig, ax, linehandle=child, targetResolution=targetResolution, scalePrecision=scalePrecision)
97+
elif isinstance(child, mpl.image.AxesImage):
98+
pass
99+
elif isinstance(child, mpl.patches.Patch):
100+
pass
101+
elif isinstance(child, mpl.collections.PathCollection):
102+
import warnings
103+
warnings.warn("Cleaning Path Collections (scatter plot) is not supported yet.")
104+
elif isinstance(child, mpl.collections.LineCollection):
105+
import warnings
106+
warnings.warn("Cleaning Line Collections (scatter plot) is not supported yet.")
107+
elif isinstance(child, mpl_toolkits.mplot3d.art3d.Line3DCollection):
108+
import warnings
109+
warnings.warn("Cleaning Line3DCollection is not supported yet.")
110+
elif isinstance(child, mpl_toolkits.mplot3d.art3d.Poly3DCollection):
111+
import warnings
112+
warnings.warn("Cleaning Poly3DCollections is not supported yet.")
86113
else:
87-
raise NotImplementedError
114+
pass
88115

89-
# clean up ax.bar and ax.hist
90-
for container in axhandle.containers:
91-
if type(container) == matplotlib.container.BarContainer:
116+
def _clean_containers(axes):
117+
"""Containers are not children of axes. They need to be visited separately"""
118+
for container in axes.containers:
119+
if isinstance(container, mpl.container.BarContainer):
92120
import warnings
93-
warnings.warn("bar and histogram simplification not implemented. Doing Nothing")
94-
95-
# clean up ax.scatter
96-
for collection in axhandle.collections:
97-
import warnings
98-
warnings.warn("scatter simplification not implemented. Doing Nothing")
99-
121+
warnings.warn("Cleaning Bar Container (bar plot) is not supported yet.")
100122

101123

102-
def _cleanline(fighandle, axhandle, linehandle, target_resolution, scalePrecision):
124+
def _cleanline(fighandle, axhandle, linehandle, targetResolution, scalePrecision):
103125
"""Clean a 2D Line plot figure.
104126
105127
Parameters
@@ -126,7 +148,7 @@ def _cleanline(fighandle, axhandle, linehandle, target_resolution, scalePrecisio
126148
else:
127149
_pruneOutsideBox(fighandle, axhandle, linehandle)
128150
_movePointscloser(fighandle, axhandle, linehandle)
129-
_simplifyLine(fighandle, axhandle, linehandle, target_resolution)
151+
_simplifyLine(fighandle, axhandle, linehandle, targetResolution)
130152
_limitPrecision(fighandle, axhandle, linehandle, scalePrecision)
131153

132154

@@ -357,7 +379,7 @@ def _isInBox(data, xLim, yLim):
357379

358380

359381
def _lineIs3D(linehandle):
360-
return type(linehandle) == mpl_toolkits.mplot3d.art3d.Line3D
382+
return isinstance(linehandle, mpl_toolkits.mplot3d.art3d.Line3D)
361383

362384

363385
def _axIs3D(axhandle):

0 commit comments

Comments
 (0)