Skip to content

Commit 67c0a9f

Browse files
implemented pruneOutsideBox functionality
1 parent c6c62af commit 67c0a9f

File tree

1 file changed

+262
-0
lines changed

1 file changed

+262
-0
lines changed

tikzplotlib/cleanfigure.py

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
from matplotlib import pyplot as plt
2+
import numpy as np
3+
4+
5+
def getVisualLimits(fighandle, axhandle):
6+
"""Returns the visual representation of the axis limits (Respecting
7+
possible log_scaling and projection into the image plane)
8+
9+
Parameters
10+
----------
11+
fighandle : obj
12+
handle to matplotlib figure object
13+
axhandle : obj
14+
hande to matplotlib axes object
15+
16+
Returns
17+
-------
18+
np.array
19+
xLim as array of shape [2, ]
20+
np.array
21+
yLim as array of shape [2, ]
22+
"""
23+
# TODO: implement 3D functionality
24+
is3D = False
25+
26+
xLim = np.array(axhandle.get_xlim())
27+
yLim = np.array(axhandle.get_ylim())
28+
if is3D:
29+
zLim = axhandle.get_ylim()
30+
31+
# Check for logarithmic scales
32+
isXlog = axhandle.get_xscale() == "log"
33+
if isXlog:
34+
xLim = np.log10(xLim)
35+
isYLog = axhandle.get_yscale() == "log"
36+
if isYLog:
37+
yLim = np.log10(yLim)
38+
if is3D:
39+
isZLog = axhandle.get_zscale() == "log"
40+
if isZLog:
41+
zLim = np.log10(zLim)
42+
43+
return xLim, yLim
44+
45+
46+
def replaceDataWithNaN(data, id_replace):
47+
"""Replaces data at id_replace with NaNs
48+
49+
Parameters
50+
----------
51+
data : np.ndarray
52+
array of x and y data with shape [N, 2]
53+
id_replace : np.array
54+
array with indices to replace. Shape [K,]
55+
56+
Returns
57+
-------
58+
np.ndarray
59+
data with replace values
60+
"""
61+
if elements(id_replace) == 0:
62+
return data
63+
64+
# TODO: add 3D compatibility
65+
is3D = False
66+
data = data.astype(np.float32)
67+
xData, yData = np.split(data, 2, 1)
68+
xData[id_replace] = np.NaN
69+
yData[id_replace] = np.NaN
70+
return np.concatenate([xData, yData], axis=1)
71+
72+
73+
def removeData(data, id_remove):
74+
"""remove data at id_remove
75+
76+
Parameters
77+
----------
78+
data : np.ndarray
79+
array of x and y data with shape [N, 2]
80+
id_remove : np.array
81+
array of x and y data with shape [N, 2]
82+
83+
Returns
84+
-------
85+
np.ndarray
86+
new data array
87+
"""
88+
if elements(id_remove) == 0:
89+
return data
90+
91+
# TODO: add 3D compatibility
92+
is3D = False
93+
xData, yData = np.split(data, 2, 1)
94+
xData = np.delete(xData, id_remove, axis=0)
95+
yData = np.delete(yData, id_remove, axis=0)
96+
return np.concatenate([xData, yData], axis=1)
97+
98+
99+
def removeNaNs(data):
100+
"""Removes superflous NaNs in the data, i.e. those at the end/beginning of the data and consequtive ones.
101+
102+
Parameters
103+
----------
104+
data : np.ndarray
105+
array of x and y data with shape [N, 2]
106+
107+
Returns
108+
-------
109+
np.ndarray
110+
new data array
111+
"""
112+
# TODO: implement 3D functionality
113+
xData, yData = np.split(data, 2, 1)
114+
id_nan = np.any(np.isnan(data), axis=1)
115+
id_remove = np.argwhere(id_nan).reshape((-1,))
116+
id_remove = id_remove[
117+
np.concatenate(
118+
[np.array([True,]).reshape((-1,)), np.diff(id_remove, axis=0) == 1]
119+
)
120+
]
121+
122+
id_first = np.argwhere(np.logical_not(id_nan))[0]
123+
id_last = np.argwhere(np.logical_not(id_nan))[-1]
124+
125+
if elements(id_first) == 0:
126+
id_remove = np.arange(len(xData))
127+
else:
128+
id_remove = np.concatenate(
129+
[np.arange(1, id_first - 1), id_remove, np.arange(id_last + 1, len(xData))]
130+
)
131+
xData = np.delete(xData, id_remove, axis=0)
132+
yData = np.delete(yData, id_remove, axis=0)
133+
return np.concatenate([xData, yData], axis=1)
134+
135+
return data
136+
137+
138+
def isInBox(data, xLim, yLim):
139+
"""Returns a mask that indicates, whether a data point is within the limits
140+
141+
Parameters
142+
----------
143+
data : np.ndarray
144+
N x 2 array of data points. data[:, 0] are x coordinates, data[:, 1] are y coordinates
145+
xLim : np.array
146+
array with x limits. Shape [2, ]
147+
yLim : np.array
148+
array with y limits. Shape [2, ]
149+
"""
150+
maskX = np.logical_and(data[:, 0] > xLim[0], data[:, 0] < xLim[1])
151+
maskY = np.logical_and(data[:, 1] > yLim[0], data[:, 1] < yLim[1])
152+
mask = maskX & maskY
153+
return mask
154+
155+
156+
def getVisualData(axhandle, linehandle):
157+
"""Returns the visual representation of the data (Respecting possible log_scaling and projection into the image plane)
158+
159+
Parameters
160+
----------
161+
axhandle : obj
162+
handle for matplotlib axis object
163+
linehandle : obj
164+
handle for matplotlib line2D object
165+
166+
Returns
167+
-------
168+
np.ndarray
169+
xData with shape [N, 1]
170+
np.ndarray
171+
yData with shape [N, 1]
172+
"""
173+
is3D = False
174+
175+
xData = linehandle.get_xdata()
176+
yData = linehandle.get_ydata()
177+
if is3D:
178+
zData = linehandle.get_zdata()
179+
180+
isXlog = axhandle.get_xscale() == "log"
181+
if isXlog:
182+
xData = np.log10(xData)
183+
isYlog = axhandle.get_yscale() == "log"
184+
if isYlog:
185+
yData = np.log10(yData)
186+
if is3D:
187+
isZlog = axhandle.get_zscale() == "log"
188+
if isZlog:
189+
zData = np.log10(zData)
190+
191+
xData = np.reshape(xData, (-1,))
192+
yData = np.reshape(yData, (-1,))
193+
return xData, yData
194+
195+
196+
def elements(array):
197+
"""check if array has elements.
198+
https://stackoverflow.com/questions/11295609/how-can-i-check-whether-the-numpy-array-is-empty-or-not
199+
"""
200+
return array.ndim and array.size
201+
202+
203+
def pruneOutsideBox(fighandle, axhandle, linehandle):
204+
"""Some sections of the line may sit outside of the visible box. Cut those off.
205+
206+
This method is not pure because it updates the linehandle object's data.
207+
208+
Parameters
209+
----------
210+
fighandle : obj
211+
matplotlib figure handle object
212+
axhandle : obj
213+
matplotlib axes handle object
214+
linehandle : obj
215+
matplotlib line2D handle object
216+
217+
Returns
218+
-------
219+
"""
220+
xData, yData = getVisualData(axhandle, linehandle)
221+
222+
data = np.stack([xData, yData], axis=1)
223+
224+
if elements(data) == 0:
225+
return
226+
227+
hasLines = (linehandle.get_linestyle() is not None) and (
228+
linehandle.get_linewidth() > 0.0
229+
)
230+
231+
xLim, yLim = getVisualLimits(fighandle, axhandle)
232+
233+
tol = 1.0e-10
234+
relaxedXLim = xLim + np.array([-tol, tol])
235+
relaxedYLim = yLim + np.array([-tol, tol])
236+
237+
dataIsInBox = isInBox(data, relaxedXLim, relaxedYLim)
238+
239+
shouldPlot = dataIsInBox
240+
if hasLines:
241+
pass
242+
# TODO: adapt this snippet from matlab2tikz
243+
# segvis = segmentVisible(data, dataIsInBox, xLim, yLim)
244+
# shouldPlot = shouldPlot | [false; segvis] | [segvis; false];
245+
246+
if not np.all(shouldPlot):
247+
id_remove = np.argwhere(np.logical_not(shouldPlot))
248+
249+
# If there are consecutive data points to be removed, only replace
250+
# the first one by a NaN. Consecutive data points have
251+
# diff(id_remove)==1, so replace diff(id_remove)>1 by NaN and remove
252+
# the rest
253+
idx = np.diff(id_remove, axis=0) > 1
254+
idx = np.concatenate([np.array([True,]).reshape((-1, 1)), idx], axis=0)
255+
256+
id_replace = id_remove[idx]
257+
id_remove = id_remove[np.logical_not(idx)]
258+
data = replaceDataWithNaN(data, id_replace)
259+
data = removeData(data, id_remove)
260+
data = removeNaNs(data)
261+
linehandle.set_xdata(data[:, 0])
262+
linehandle.set_ydata(data[:, 1])

0 commit comments

Comments
 (0)