Skip to content

Commit c9f6372

Browse files
refactoring
1 parent 87cd23e commit c9f6372

File tree

2 files changed

+131
-380
lines changed

2 files changed

+131
-380
lines changed

test/test_cleanfigure.py

Lines changed: 62 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import pytest
33
from matplotlib import pyplot as plt
44

5-
from tikzplotlib import cleanfigure, get_tikz_code
5+
from tikzplotlib import get_tikz_code
6+
from tikzplotlib import _cleanfigure as cleanfigure
67

78
RC_PARAMS = {"figure.figsize": [5, 5], "figure.dpi": 220, "pgf.rcfonts": False}
89

@@ -45,7 +46,7 @@ def test_pruneOutsideBox():
4546
(l,) = ax.plot(x, y)
4647
ax.set_ylim([20, 80])
4748
ax.set_xlim([20, 80])
48-
cleanfigure._pruneOutsideBox(fig, ax, l)
49+
cleanfigure._prune_outside_box(fig, ax, l)
4950
assert l.get_xdata().shape == (14,)
5051
plt.close("all")
5152

@@ -75,7 +76,7 @@ def test_replaceDataWithNaN():
7576
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
7677
(l,) = ax.plot(xData, yData)
7778

78-
cleanfigure._replaceDataWithNan(l, id_replace)
79+
cleanfigure._replace_data_with_NaN(l, id_replace)
7980

8081
newdata = np.stack(l.get_data(), axis=1)
8182
assert newdata.shape == data.shape
@@ -107,7 +108,7 @@ def test_removeData():
107108
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
108109
(l,) = ax.plot(xData, yData)
109110

110-
cleanfigure._removeData(l, id_remove)
111+
cleanfigure._remove_data(l, id_remove)
111112
newdata = np.stack(l.get_data(), axis=1)
112113
assert newdata.shape == (14, 2)
113114
plt.close("all")
@@ -136,9 +137,9 @@ def test_removeNaNs():
136137
with plt.rc_context(rc=RC_PARAMS):
137138
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
138139
(l,) = ax.plot(xData, yData)
139-
cleanfigure._replaceDataWithNan(l, id_replace)
140-
cleanfigure._removeData(l, id_remove)
141-
cleanfigure._removeNaNs(l)
140+
cleanfigure._replace_data_with_NaN(l, id_replace)
141+
cleanfigure._remove_data(l, id_remove)
142+
cleanfigure._remove_NaNs(l)
142143
newdata = np.stack(l.get_data(), axis=1)
143144
assert not np.any(np.isnan(newdata))
144145
assert newdata.shape == (12, 2)
@@ -194,7 +195,7 @@ def test_getVisualLimits():
194195
(l,) = ax.plot(x, y)
195196
ax.set_xlim([20, 80])
196197
ax.set_ylim([20, 80])
197-
xLim, yLim = cleanfigure._getVisualLimits(fig, ax)
198+
xLim, yLim = cleanfigure._get_visual_limits(fig, ax)
198199
assert np.allclose(xLim, np.array([20, 80]))
199200
assert np.allclose(yLim, np.array([20, 80]))
200201
plt.close("all")
@@ -227,8 +228,8 @@ def test_movePointsCloser():
227228
(l,) = ax.plot(x, y)
228229
ax.set_ylim([20, 80])
229230
ax.set_xlim([20, 80])
230-
cleanfigure._pruneOutsideBox(fig, ax, l)
231-
cleanfigure._movePointscloser(fig, ax, l)
231+
cleanfigure._prune_outside_box(fig, ax, l)
232+
cleanfigure._move_points_closer(fig, ax, l)
232233
assert l.get_xdata().shape == (14,)
233234
plt.close("all")
234235

@@ -260,49 +261,14 @@ def test_simplifyLine():
260261
(l,) = ax.plot(x, y)
261262
ax.set_ylim([20, 80])
262263
ax.set_xlim([20, 80])
263-
cleanfigure._pruneOutsideBox(fig, ax, l)
264-
cleanfigure._movePointscloser(fig, ax, l)
265-
cleanfigure._simplifyLine(fig, ax, l, 600)
264+
cleanfigure._prune_outside_box(fig, ax, l)
265+
cleanfigure._move_points_closer(fig, ax, l)
266+
cleanfigure._simplify_line(fig, ax, l, 600)
266267
assert l.get_xdata().shape == (2,)
267268
assert l.get_ydata().shape == (2,)
268269
plt.close("all")
269270

270271

271-
# def test_simplifyStairs():
272-
# """octave code
273-
274-
# ```octave
275-
# %% example 4
276-
277-
# addpath ("../matlab2tikz/src")
278-
279-
# x = linspace(1, 100, 20);
280-
# y1 = linspace(1, 100, 20);
281-
282-
# figure
283-
# stairs(x, y1)
284-
# xlim([20, 80])
285-
# ylim([20, 80])
286-
# set(gcf,'Units','Inches');
287-
# set(gcf,'Position',[2.5 2.5 5 5])
288-
# cleanfigure;
289-
# ```
290-
# """
291-
# # TODO: it looks like matlab changes the data to be plotted when using `stairs` command,
292-
# # whereas matplotlib stores the same data but displays it as a step.
293-
# x = np.linspace(1, 100, 20)
294-
# y = np.linspace(1, 100, 20)
295-
296-
# with plt.rc_context(rc=RC_PARAMS):
297-
# fig, ax = plt.subplots(1, 1, figsize=(5, 5))
298-
# (l,) = ax.step(x, y, where="post")
299-
# ax.set_ylim([20, 80])
300-
# ax.set_xlim([20, 80])
301-
# cleanfigure.pruneOutsideBox(fig, ax, l)
302-
# cleanfigure.movePointscloser(fig, ax, l)
303-
# cleanfigure.simplifyStairs(fig, ax, l)
304-
305-
306272
def test_limitPrecision():
307273
"""octave code
308274
```octave
@@ -330,10 +296,10 @@ def test_limitPrecision():
330296
(l,) = ax.plot(x, y)
331297
ax.set_ylim([20, 80])
332298
ax.set_xlim([20, 80])
333-
cleanfigure._pruneOutsideBox(fig, ax, l)
334-
cleanfigure._movePointscloser(fig, ax, l)
335-
cleanfigure._simplifyLine(fig, ax, l, 600)
336-
cleanfigure._limitPrecision(fig, ax, l, 1)
299+
cleanfigure._prune_outside_box(fig, ax, l)
300+
cleanfigure._move_points_closer(fig, ax, l)
301+
cleanfigure._simplify_line(fig, ax, l, 600)
302+
cleanfigure._limit_precision(fig, ax, l, 1)
337303
assert l.get_xdata().shape == (2,)
338304
assert l.get_ydata().shape == (2,)
339305
plt.close("all")
@@ -378,7 +344,7 @@ def test_opheimSimplify():
378344
)
379345
y = x.copy()
380346
tol = 0.02
381-
mask = cleanfigure._opheimSimplify(x, y, tol)
347+
mask = cleanfigure._opheim_simplify(x, y, tol)
382348
assert mask.shape == (12,)
383349
assert np.allclose(mask * 1, np.array([1] + [0] * 10 + [1]))
384350

@@ -422,7 +388,7 @@ def test_plot(self):
422388
ax.set_xlim([20, 80])
423389
raw = get_tikz_code()
424390

425-
cleanfigure.cleanfigure(fig)
391+
cleanfigure.clean_figure(fig)
426392
clean = get_tikz_code()
427393

428394
# Use number of lines to test if it worked.
@@ -446,7 +412,7 @@ def test_step(self):
446412
ax.set_ylim([20, 80])
447413
ax.set_xlim([20, 80])
448414
with pytest.warns(Warning):
449-
cleanfigure.cleanfigure(fig)
415+
cleanfigure.clean_figure(fig)
450416
plt.close("all")
451417

452418
def test_scatter(self):
@@ -461,7 +427,7 @@ def test_scatter(self):
461427
ax.set_ylim([20, 80])
462428
ax.set_xlim([20, 80])
463429
with pytest.warns(Warning):
464-
cleanfigure.cleanfigure(fig)
430+
cleanfigure.clean_figure(fig)
465431
plt.close("all")
466432

467433
def test_bar(self):
@@ -475,7 +441,7 @@ def test_bar(self):
475441
ax.set_ylim([20, 80])
476442
ax.set_xlim([20, 80])
477443
with pytest.warns(Warning):
478-
cleanfigure.cleanfigure(fig)
444+
cleanfigure.clean_figure(fig)
479445
plt.close("all")
480446

481447
def test_hist(self):
@@ -489,7 +455,7 @@ def test_hist(self):
489455
ax.set_ylim([20, 80])
490456
ax.set_xlim([20, 80])
491457
with pytest.warns(Warning):
492-
cleanfigure.cleanfigure(fig)
458+
cleanfigure.clean_figure(fig)
493459
plt.close("all")
494460

495461
def test_plot3d(self):
@@ -511,7 +477,7 @@ def test_plot3d(self):
511477
ax.view_init(30, 30)
512478
raw = get_tikz_code(fig)
513479

514-
cleanfigure.cleanfigure(fig)
480+
cleanfigure.clean_figure(fig)
515481
clean = get_tikz_code()
516482

517483
# Use number of lines to test if it worked.
@@ -534,7 +500,7 @@ def test_scatter3d(self):
534500
ax.set_ylim([20, 80])
535501
ax.set_zlim([0, 80])
536502
with pytest.warns(Warning):
537-
cleanfigure.cleanfigure(fig)
503+
cleanfigure.clean_figure(fig)
538504
plt.close("all")
539505

540506
def test_wireframe3D(self):
@@ -551,7 +517,7 @@ def test_wireframe3D(self):
551517
# Plot a basic wireframe.
552518
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10)
553519
with pytest.warns(Warning):
554-
cleanfigure.cleanfigure(fig)
520+
cleanfigure.clean_figure(fig)
555521
plt.close("all")
556522

557523
def test_surface3D(self):
@@ -584,7 +550,7 @@ def test_surface3D(self):
584550
fig.colorbar(surf, shrink=0.5, aspect=5)
585551

586552
with pytest.warns(Warning):
587-
cleanfigure.cleanfigure(fig)
553+
cleanfigure.clean_figure(fig)
588554
plt.close("all")
589555

590556
def test_trisurface3D(self):
@@ -620,7 +586,7 @@ def test_trisurface3D(self):
620586

621587
ax.plot_trisurf(x, y, z, linewidth=0.2, antialiased=True)
622588
with pytest.warns(Warning):
623-
cleanfigure.cleanfigure(fig)
589+
cleanfigure.clean_figure(fig)
624590
plt.close("all")
625591

626592
def test_contour3D(self):
@@ -636,7 +602,7 @@ def test_contour3D(self):
636602
cset = ax.contour(X, Y, Z, cmap=cm.coolwarm)
637603
ax.clabel(cset, fontsize=9, inline=1)
638604
with pytest.warns(Warning):
639-
cleanfigure.cleanfigure(fig)
605+
cleanfigure.clean_figure(fig)
640606
plt.close("all")
641607

642608
def test_polygon3D(self):
@@ -677,7 +643,7 @@ def cc(arg):
677643
ax.set_zlabel("Z")
678644
ax.set_zlim3d(0, 1)
679645
with pytest.warns(Warning):
680-
cleanfigure.cleanfigure(fig)
646+
cleanfigure.clean_figure(fig)
681647
plt.close("all")
682648

683649
def test_bar3D(self):
@@ -702,7 +668,7 @@ def test_bar3D(self):
702668
ax.set_ylabel("Y")
703669
ax.set_zlabel("Z")
704670
with pytest.warns(Warning):
705-
cleanfigure.cleanfigure(fig)
671+
cleanfigure.clean_figure(fig)
706672
plt.close("all")
707673

708674
def test_quiver3D(self):
@@ -733,7 +699,7 @@ def test_quiver3D(self):
733699

734700
ax.quiver(x, y, z, u, v, w, length=0.1, normalize=True)
735701
with pytest.warns(Warning):
736-
cleanfigure.cleanfigure(fig)
702+
cleanfigure.clean_figure(fig)
737703
plt.close("all")
738704

739705
def test_2D_in_3D(self):
@@ -774,7 +740,7 @@ def test_2D_in_3D(self):
774740
# on the plane y=0
775741
ax.view_init(elev=20.0, azim=-35)
776742
with pytest.warns(Warning):
777-
cleanfigure.cleanfigure(fig)
743+
cleanfigure.clean_figure(fig)
778744
plt.close("all")
779745

780746

@@ -798,7 +764,7 @@ def test_line_no_markers(self):
798764
ax.set_xlim([20, 80])
799765
raw = get_tikz_code()
800766

801-
cleanfigure.cleanfigure(fig)
767+
cleanfigure.clean_figure(fig)
802768
clean = get_tikz_code()
803769

804770
# Use number of lines to test if it worked.
@@ -827,7 +793,7 @@ def test_no_line_markers(self):
827793
ax.set_xlim([20, 80])
828794
raw = get_tikz_code()
829795

830-
cleanfigure.cleanfigure(fig)
796+
cleanfigure.clean_figure(fig)
831797
clean = get_tikz_code()
832798

833799
# Use number of lines to test if it worked.
@@ -856,7 +822,7 @@ def test_line_markers(self):
856822
ax.set_xlim([20, 80])
857823
raw = get_tikz_code()
858824

859-
cleanfigure.cleanfigure(fig)
825+
cleanfigure.clean_figure(fig)
860826
clean = get_tikz_code()
861827

862828
# Use number of lines to test if it worked.
@@ -881,7 +847,7 @@ def test_sine(self):
881847
ax.set_ylim([-1, 1])
882848
raw = get_tikz_code()
883849

884-
cleanfigure.cleanfigure(fig)
850+
cleanfigure.clean_figure(fig)
885851
clean = get_tikz_code()
886852

887853
# Use number of lines to test if it worked.
@@ -937,7 +903,7 @@ def test_subplot(self):
937903
ax.set_xlim([20, 80])
938904
raw = get_tikz_code()
939905

940-
cleanfigure.cleanfigure(fig)
906+
cleanfigure.clean_figure(fig)
941907
clean = get_tikz_code()
942908

943909
# Use number of lines to test if it worked.
@@ -974,7 +940,7 @@ def test_segmentVisible():
974940
dataIsInBox = np.array([0] * 4 + [1] * 12 + [0] * 4) == 1
975941
xLim = np.array([20, 80])
976942
yLim = np.array([20, 80])
977-
mask = cleanfigure._segmentVisible(data, dataIsInBox, xLim, yLim)
943+
mask = cleanfigure._segment_visible(data, dataIsInBox, xLim, yLim)
978944
assert np.allclose(mask * 1, np.array([0] * 3 + [1] * 13 + [0] * 3))
979945

980946

@@ -987,7 +953,7 @@ def test_crossLines():
987953
X2 = data[1:, :]
988954
X3 = np.array([80, 20])
989955
X4 = np.array([80, 80])
990-
Lambda = cleanfigure._crossLines(X1, X2, X3, X4)
956+
Lambda = cleanfigure._cross_lines(X1, X2, X3, X4)
991957

992958
expected_result = np.array(
993959
[
@@ -1024,7 +990,7 @@ def test_segmentsIntersect():
1024990
X2 = data[1:, :]
1025991
X3 = np.array([80, 20])
1026992
X4 = np.array([80, 80])
1027-
mask = cleanfigure._segmentsIntersect(X1, X2, X3, X4)
993+
mask = cleanfigure._segments_intersect(X1, X2, X3, X4)
1028994
assert np.allclose(mask * 1, np.zeros_like(mask))
1029995

1030996

@@ -1061,3 +1027,23 @@ def test_corners3D():
10611027

10621028
assert corners.shape == (8, 3)
10631029
assert np.sum(corners) == 0
1030+
1031+
1032+
def test_corners2D():
1033+
xLim = np.array([20, 80])
1034+
yLim = np.array([20, 80])
1035+
corners = cleanfigure._corners2D(xLim, yLim)
1036+
1037+
import itertools
1038+
1039+
expected_output = tuple(np.array(t) for t in itertools.product([20, 80], [20, 80]))
1040+
assert np.allclose(corners, expected_output)
1041+
1042+
1043+
def test_getHeightWidthInPixels():
1044+
with plt.rc_context(rc=RC_PARAMS):
1045+
fig, axes = plt.subplots(1, 1, figsize=(5, 5))
1046+
w, h = cleanfigure._get_width_height_in_pixels(fig, [600, 400])
1047+
assert w == 600 and h == 400
1048+
w, h = cleanfigure._get_width_height_in_pixels(fig, 600)
1049+
assert w == h

0 commit comments

Comments
 (0)