Skip to content

Commit 7e5de37

Browse files
scatter plot support
1 parent 17f70ba commit 7e5de37

File tree

2 files changed

+251
-231
lines changed

2 files changed

+251
-231
lines changed

test/test_cleanfigure.py

Lines changed: 89 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -8,58 +8,73 @@
88
RC_PARAMS = {"figure.figsize": [5, 5], "figure.dpi": 220, "pgf.rcfonts": False}
99

1010

11-
def test_clean_figure():
12-
x = np.linspace(1, 100, 20)
13-
y = np.linspace(1, 100, 20)
14-
15-
with plt.rc_context(rc=RC_PARAMS):
16-
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
17-
(l,) = ax.plot(x, y)
18-
ax.set_ylim([20, 80])
19-
ax.set_xlim([20, 80])
20-
cleanfigure.clean_figure(fig)
21-
plt.close("all")
11+
class Test_pruneOutsideBox:
12+
def test_pruneOutsideBox2D(self):
13+
"""test against matlab2tikz implementation
2214
15+
octave code to generate baseline results.
16+
Note that octave has indexing 1...N, whereas python has indexing 0...N-1.
17+
```octave
18+
x = linspace(1, 100, 20);
19+
y1 = linspace(1, 100, 20);
2320
24-
def test_pruneOutsideBox():
25-
"""test against matlab2tikz implementation
21+
figure
22+
plot(x, y1)
23+
xlim([20, 80])
24+
ylim([20, 80])
25+
cleanfigure;
26+
```
27+
"""
28+
x = np.linspace(1, 100, 20)
29+
y = np.linspace(1, 100, 20)
2630

27-
octave code to generate baseline results.
28-
Note that octave has indexing 1...N, whereas python has indexing 0...N-1.
29-
```octave
30-
x = linspace(1, 100, 20);
31-
y1 = linspace(1, 100, 20);
31+
with plt.rc_context(rc=RC_PARAMS):
32+
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
33+
(l,) = ax.plot(x, y)
34+
ax.set_ylim([20, 80])
35+
ax.set_xlim([20, 80])
36+
axhandle = ax
37+
linehandle = l
38+
fighandle = fig
39+
data, is3D = cleanfigure._get_line_data(linehandle)
40+
xLim, yLim = cleanfigure._get_visual_limits(fighandle, axhandle)
41+
visual_data = cleanfigure._get_visual_data(axhandle, data, is3D)
42+
hasLines = cleanfigure._line_has_lines(linehandle)
43+
44+
data = cleanfigure._prune_outside_box(
45+
xLim, yLim, data, visual_data, is3D, hasLines
46+
)
47+
assert data.shape == (14, 2)
3248

33-
figure
34-
plot(x, y1)
35-
xlim([20, 80])
36-
ylim([20, 80])
37-
cleanfigure;
38-
```
39-
"""
40-
x = np.linspace(1, 100, 20)
41-
y = np.linspace(1, 100, 20)
49+
def test_pruneOutsideBox3D(self):
50+
theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)
51+
z = np.linspace(-2, 2, 100)
52+
r = z ** 2 + 1
53+
x = r * np.sin(theta)
54+
y = r * np.cos(theta)
4255

43-
with plt.rc_context(rc=RC_PARAMS):
44-
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
45-
(l,) = ax.plot(x, y)
46-
ax.set_ylim([20, 80])
47-
ax.set_xlim([20, 80])
48-
axhandle = ax
49-
linehandle = l
50-
fighandle = fig
51-
xData, yData = cleanfigure._get_visual_data(axhandle, linehandle)
52-
visual_data = cleanfigure._stack_data_2D(xData, yData)
53-
data = cleanfigure._get_data(linehandle)
54-
xLim, yLim = cleanfigure._get_visual_limits(fighandle, axhandle)
55-
is3D = cleanfigure._lineIs3D(linehandle)
56-
hasLines = cleanfigure._line_has_lines(linehandle)
56+
with plt.rc_context(rc=RC_PARAMS):
57+
fig = plt.figure()
58+
ax = fig.add_subplot(111, projection="3d")
59+
(l,) = ax.plot(x, y, z)
60+
ax.set_xlim([-2, 2])
61+
ax.set_ylim([-2, 2])
62+
ax.set_zlim([-2, 2])
63+
ax.view_init(30, 30)
5764

58-
data = cleanfigure._prune_outside_box(
59-
xLim, yLim, data, visual_data, is3D, hasLines
60-
)
61-
assert data.shape == (14, 2)
65+
axhandle = ax
66+
linehandle = l
67+
fighandle = fig
68+
data, is3D = cleanfigure._get_line_data(linehandle)
69+
xLim, yLim = cleanfigure._get_visual_limits(fighandle, axhandle)
70+
visual_data = cleanfigure._get_visual_data(axhandle, data, is3D)
71+
hasLines = cleanfigure._line_has_lines(linehandle)
6272

73+
data = cleanfigure._prune_outside_box(
74+
xLim, yLim, data, visual_data, is3D, hasLines
75+
)
76+
assert data.shape == (86, 3)
77+
plt.close("all")
6378

6479

6580
def test_replaceDataWithNaN():
@@ -84,8 +99,8 @@ def test_replaceDataWithNaN():
8499
data = np.stack([xData, yData], axis=1)
85100

86101
newdata = cleanfigure._replace_data_with_NaN(data, id_replace, False)
87-
assert newdata.shape == data.shape
88-
assert np.any(np.isnan(newdata))
102+
assert newdata.shape == data.shape
103+
assert np.any(np.isnan(newdata))
89104

90105

91106
def test_removeData():
@@ -132,17 +147,13 @@ def test_removeNaNs():
132147
id_remove = np.array([1, 2, 3, 17, 18, 19])
133148
xData = np.linspace(1, 100, 20)
134149
yData = xData.copy()
150+
data = cleanfigure._stack_data_2D(xData, yData)
135151

136-
with plt.rc_context(rc=RC_PARAMS):
137-
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
138-
(l,) = ax.plot(xData, yData)
139-
cleanfigure._replace_data_with_NaN(l, id_replace)
140-
cleanfigure._remove_data(l, id_remove)
141-
cleanfigure._remove_NaNs(l)
142-
newdata = np.stack(l.get_data(), axis=1)
143-
assert not np.any(np.isnan(newdata))
144-
assert newdata.shape == (12, 2)
145-
plt.close("all")
152+
data = cleanfigure._replace_data_with_NaN(data, id_replace, False)
153+
data = cleanfigure._remove_data(data, id_remove, False)
154+
data = cleanfigure._remove_NaNs(data)
155+
assert not np.any(np.isnan(data))
156+
assert data.shape == (12, 2)
146157

147158

148159
def test_isInBox():
@@ -200,104 +211,6 @@ def test_getVisualLimits():
200211
plt.close("all")
201212

202213

203-
def test_movePointsCloser():
204-
"""octave code
205-
```octave
206-
addpath ("../matlab2tikz/src")
207-
208-
x = linspace(1, 100, 20);
209-
y1 = linspace(1, 100, 20);
210-
211-
figure
212-
plot(x, y1)
213-
xlim([20, 80])
214-
ylim([20, 80])
215-
set(gcf,'Units','Inches');
216-
set(gcf,'Position',[2.5 2.5 5 5])
217-
cleanfigure;
218-
```
219-
"""
220-
x = np.linspace(1, 100, 20)
221-
y = np.linspace(1, 100, 20)
222-
223-
with plt.rc_context(rc=RC_PARAMS):
224-
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
225-
(l,) = ax.plot(x, y)
226-
ax.set_ylim([20, 80])
227-
ax.set_xlim([20, 80])
228-
cleanfigure._prune_outside_box(fig, ax, l)
229-
cleanfigure._move_points_closer(fig, ax, l)
230-
assert l.get_xdata().shape == (14,)
231-
plt.close("all")
232-
233-
234-
def test_simplifyLine():
235-
"""octave code
236-
```octave
237-
addpath ("../matlab2tikz/src")
238-
239-
x = linspace(1, 100, 20);
240-
y1 = linspace(1, 100, 20);
241-
242-
figure
243-
plot(x, y1)
244-
xlim([20, 80])
245-
ylim([20, 80])
246-
set(gcf,'Units','Inches');
247-
set(gcf,'Position',[2.5 2.5 5 5])
248-
cleanfigure;
249-
```
250-
"""
251-
x = np.linspace(1, 100, 20)
252-
y = np.linspace(1, 100, 20)
253-
254-
with plt.rc_context(rc=RC_PARAMS):
255-
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
256-
(l,) = ax.plot(x, y)
257-
ax.set_ylim([20, 80])
258-
ax.set_xlim([20, 80])
259-
cleanfigure._prune_outside_box(fig, ax, l)
260-
cleanfigure._move_points_closer(fig, ax, l)
261-
cleanfigure._simplify_line(fig, ax, l, 600)
262-
assert l.get_xdata().shape == (2,)
263-
assert l.get_ydata().shape == (2,)
264-
plt.close("all")
265-
266-
267-
def test_limitPrecision():
268-
"""octave code
269-
```octave
270-
addpath ("../matlab2tikz/src")
271-
272-
x = linspace(1, 100, 20);
273-
y1 = linspace(1, 100, 20);
274-
275-
figure
276-
plot(x, y1)
277-
xlim([20, 80])
278-
ylim([20, 80])
279-
set(gcf,'Units','Inches');
280-
set(gcf,'Position',[2.5 2.5 5 5])
281-
cleanfigure;
282-
```
283-
"""
284-
x = np.linspace(1, 100, 20)
285-
y = np.linspace(1, 100, 20)
286-
287-
with plt.rc_context(rc=RC_PARAMS):
288-
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
289-
(l,) = ax.plot(x, y)
290-
ax.set_ylim([20, 80])
291-
ax.set_xlim([20, 80])
292-
cleanfigure._prune_outside_box(fig, ax, l)
293-
cleanfigure._move_points_closer(fig, ax, l)
294-
cleanfigure._simplify_line(fig, ax, l, 600)
295-
cleanfigure._limit_precision(fig, ax, l, 1)
296-
assert l.get_xdata().shape == (2,)
297-
assert l.get_ydata().shape == (2,)
298-
plt.close("all")
299-
300-
301214
def test_opheimSimplify():
302215
"""test path simplification
303216
@@ -467,22 +380,35 @@ def test_plot3d(self):
467380
# Use number of lines to test if it worked.
468381
numLinesRaw = raw.count("\n")
469382
numLinesClean = clean.count("\n")
383+
470384
assert numLinesRaw - numLinesClean == 14
471385
plt.close("all")
472386

473387
def test_scatter3d(self):
474-
x, y = np.meshgrid(np.linspace(1, 100, 20), np.linspace(1, 100, 20))
475-
z = np.abs(x - 50) + np.abs(y - 50)
388+
theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)
389+
z = np.linspace(-2, 2, 100)
390+
r = z ** 2 + 1
391+
x = r * np.sin(theta)
392+
y = r * np.cos(theta)
476393

477394
with plt.rc_context(rc=RC_PARAMS):
478395
fig = plt.figure()
479396
ax = fig.add_subplot(111, projection="3d")
480397
ax.scatter(x, y, z)
481-
ax.set_xlim([20, 80])
482-
ax.set_ylim([20, 80])
483-
ax.set_zlim([0, 80])
484-
with pytest.warns(Warning):
485-
cleanfigure.clean_figure(fig)
398+
ax.set_xlim([-2, 2])
399+
ax.set_ylim([-2, 2])
400+
ax.set_zlim([-2, 2])
401+
ax.view_init(30, 30)
402+
raw = get_tikz_code(fig)
403+
404+
cleanfigure.clean_figure(fig)
405+
clean = get_tikz_code()
406+
407+
# Use number of lines to test if it worked.
408+
numLinesRaw = raw.count("\n")
409+
numLinesClean = clean.count("\n")
410+
411+
assert numLinesRaw - numLinesClean == 14
486412
plt.close("all")
487413

488414
def test_wireframe3D(self):
@@ -717,7 +643,7 @@ def test_2D_in_3D(self):
717643
plt.close("all")
718644

719645

720-
class Test_lineplot:
646+
class Test_lineplot_markers:
721647
def test_line_no_markers(self):
722648
"""test high-level usage for simple example.
723649
Test is successfull if generated tikz code saves correct amount of lines

0 commit comments

Comments
 (0)