Skip to content

Commit cba254e

Browse files
implements subplot support for cleanfigure
* `cleanfigure.cleanfigure` now recursively visits all axes. * implemented `cleanfigure.diff` to replace matlab's `diff` function * bug fixes in `removeNaNs` * added octave code docstring in subplot test
1 parent 7cf0d8d commit cba254e

File tree

2 files changed

+57
-12
lines changed

2 files changed

+57
-12
lines changed

test/test_cleanfigure.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,30 @@ def test_sine(self):
488488

489489
class Test_subplots:
490490
def test_subplot(self):
491+
"""octave code
492+
493+
```octave
494+
addpath ("../matlab2tikz/src")
495+
496+
x = linspace(1, 100, 20);
497+
y1 = linspace(1, 100, 20);
498+
499+
figure
500+
subplot(2, 2, 1)
501+
plot(x, y1, "-")
502+
subplot(2, 2, 2)
503+
plot(x, y1, "-")
504+
subplot(2, 2, 3)
505+
plot(x, y1, "-")
506+
subplot(2, 2, 4)
507+
plot(x, y1, "-")
508+
xlim([20, 80])
509+
ylim([20, 80])
510+
set(gcf,'Units','Inches');
511+
set(gcf,'Position',[2.5 2.5 5 5])
512+
cleanfigure;
513+
```
514+
"""
491515
from tikzplotlib import get_tikz_code
492516

493517
x = np.linspace(1, 100, 20)

tikzplotlib/cleanfigure.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,19 @@
1010

1111

1212
def cleanfigure(fighandle=None, axhandle=None, target_resolution=600, scalePrecision=1.0):
13-
"""cleans figure as a preparation for tikz export.
13+
"""Cleans figure as a preparation for tikz export.
1414
This will minimize the number of points required for the tikz figure.
15+
If the figure has subplots, it will recursively clean then up.
1516
1617
Note that this function modifies the figure directly (impure function).
18+
1719
1820
Parameters
1921
----------
2022
fighandle : obj, optional
21-
matplotlib figure handle object. If not provided, it is obtained from gcf(), by default None
23+
matplotlib figure handle object. If not provided, it is obtained from `plt.gcf()`, by default None
2224
axhandle : obj, optional
23-
matplotlib figure handle object. If not provided, it is obtained from gca(), by default None
25+
matplotlib figure handle object. If not provided, it is obtained from `plt.gcf().axes`, by default None
2426
target_resolution : int, list of int or np.array
2527
target resolution of final figure in PPI. If a scalar integer is provided, it is assumed to be square in both axis.
2628
If a list or an np.array is provided, it is interpreted as [H, W], by default 600
@@ -29,11 +31,15 @@ def cleanfigure(fighandle=None, axhandle=None, target_resolution=600, scalePreci
2931
"""
3032
if fighandle is None and axhandle is None:
3133
fighandle = plt.gcf()
32-
axhandle = plt.gca()
34+
# recurse into subplots
35+
for axhandle in fighandle.axes:
36+
cleanfigure(fighandle, axhandle, target_resolution, scalePrecision)
3337
elif fighandle is None and (axhandle is not None):
3438
fighandle = axhandle.get_figure()
3539
elif (fighandle is not None) and (axhandle is None):
36-
axhandle = fighandle.axes[0]
40+
# recurse into subplots
41+
for axhandle in fighandle.axes:
42+
cleanfigure(fighandle, axhandle, target_resolution, scalePrecision)
3743

3844
# Note: ax.scatter and ax.plot create Line2D objects in a property ax.lines
3945
# ax.bar creates BarContainer objects in a property ax.bar
@@ -167,6 +173,17 @@ def removeData(data, id_remove):
167173
return np.concatenate([xData, yData], axis=1)
168174

169175

176+
def diff(x, *args, **kwargs):
177+
"""modification of np.diff(x, *args, **kwargs).
178+
- If x is empty, return np.array([False])
179+
- else: return np.diff(x, *args, **kwargs)
180+
"""
181+
if isempty(x):
182+
return np.array([False])
183+
else:
184+
return np.diff(x, *args, **kwargs)
185+
186+
170187
def removeNaNs(data):
171188
"""Removes superflous NaNs in the data, i.e. those at the end/beginning of the data and consecutive ones.
172189
@@ -184,20 +201,24 @@ def removeNaNs(data):
184201
xData, yData = np.split(data, 2, 1)
185202
id_nan = np.any(np.isnan(data), axis=1)
186203
id_remove = np.argwhere(id_nan).reshape((-1,))
187-
id_remove = id_remove[
188-
np.concatenate(
189-
[np.array([True,]).reshape((-1,)), np.diff(id_remove, axis=0) == 1]
190-
)
191-
]
204+
if isempty(id_remove):
205+
pass
206+
else:
207+
id_remove = id_remove[
208+
np.concatenate(
209+
[diff(id_remove, axis=0) == 1, np.array([False,]).reshape((-1,))]
210+
)
211+
]
192212

193213
id_first = np.argwhere(np.logical_not(id_nan))[0]
194214
id_last = np.argwhere(np.logical_not(id_nan))[-1]
195215

196-
if elements(id_first) == 0:
216+
if isempty(id_first):
217+
# remove entire data
197218
id_remove = np.arange(len(xData))
198219
else:
199220
id_remove = np.concatenate(
200-
[np.arange(1, id_first - 1), id_remove, np.arange(id_last + 1, len(xData))]
221+
[np.arange(0, id_first), id_remove, np.arange(id_last + 1, len(xData))]
201222
)
202223
xData = np.delete(xData, id_remove, axis=0)
203224
yData = np.delete(yData, id_remove, axis=0)

0 commit comments

Comments
 (0)