Skip to content

Commit 17f70ba

Browse files
refactoring
1 parent 820696e commit 17f70ba

File tree

1 file changed

+92
-54
lines changed

1 file changed

+92
-54
lines changed

tikzplotlib/_cleanfigure.py

Lines changed: 92 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def _get_visual_limits(fighandle, axhandle):
277277
return xLim, yLim
278278

279279

280-
def _replace_data_with_NaN(linehandle, id_replace):
280+
def _replace_data_with_NaN(data, id_replace, is3D):
281281
"""Replaces data at id_replace with NaNs.
282282
283283
:param data: array of x and y data with shape [N, 2]
@@ -287,34 +287,44 @@ def _replace_data_with_NaN(linehandle, id_replace):
287287
:param linehandle: matplotlib line handle object
288288
:type linehandle: object
289289
"""
290-
if _elements(id_replace) == 0:
291-
return
290+
if _isempty(data):
291+
return data
292292

293-
is3D = _lineIs3D(linehandle)
293+
# is3D = _lineIs3D(linehandle)
294294

295+
# if is3D:
296+
# xData, yData, zData = linehandle.get_data_3d()
297+
# zData = zData.copy()
298+
# else:
299+
# xData = linehandle.get_xdata().astype(np.float32)
300+
# yData = linehandle.get_ydata().astype(np.float32)
295301
if is3D:
296-
xData, yData, zData = linehandle.get_data_3d()
297-
zData = zData.copy()
302+
xData, yData, zData = _split_data_3D(data)
298303
else:
299-
xData = linehandle.get_xdata().astype(np.float32)
300-
yData = linehandle.get_ydata().astype(np.float32)
304+
xData, yData = _split_data_2D(data)
301305

302306
xData[id_replace] = np.NaN
303307
yData[id_replace] = np.NaN
304308
if is3D:
305309
zData = zData.copy()
306310
zData[id_replace] = np.NaN
307311

312+
# if is3D:
313+
# # TODO: I don't understand why I need to set both to get tikz code reduction to work
314+
# linehandle.set_data_3d(xData, yData, zData)
315+
# linehandle.set_data(xData, yData)
316+
# else:
317+
# linehandle.set_xdata(xData)
318+
# linehandle.set_ydata(yData)
319+
308320
if is3D:
309-
# TODO: I don't understand why I need to set both to get tikz code reduction to work
310-
linehandle.set_data_3d(xData, yData, zData)
311-
linehandle.set_data(xData, yData)
321+
new_data = _stack_data_3D(xData, yData, zData)
312322
else:
313-
linehandle.set_xdata(xData)
314-
linehandle.set_ydata(yData)
323+
new_data = _stack_data_2D(xData, yData)
324+
return new_data
315325

316326

317-
def _remove_data(linehandle, id_remove):
327+
def _remove_data(data, id_remove, is3D):
318328
"""remove data at id_remove
319329
320330
:param data: array of x and y data with shape [N, 2]
@@ -324,22 +334,25 @@ def _remove_data(linehandle, id_remove):
324334
:param linehandle: matplotlib linehandle object
325335
:type linehandle: object
326336
"""
327-
if _elements(id_remove) == 0:
328-
return
337+
if _isempty(data):
338+
return data
329339

330-
is3D = _lineIs3D(linehandle)
331340
if is3D:
332-
xData, yData, zData = linehandle.get_data_3d()
341+
xData, yData, zData = _split_data_3D(data)
333342
else:
334-
xData = linehandle.get_xdata().astype(np.float32)
335-
yData = linehandle.get_ydata().astype(np.float32)
343+
xData, yData = _split_data_2D(data)
336344

337345
xData = np.delete(xData, id_remove, axis=0)
338346
yData = np.delete(yData, id_remove, axis=0)
339347
if is3D:
340348
zData = np.delete(zData, id_remove, axis=0)
341349

342350
if is3D:
351+
newdata = _stack_data_3D(xData, yData, zData)
352+
else:
353+
newdata = _stack_data_2D(xData, yData)
354+
return newdata
355+
343356
# TODO: I don't understand why I need to set both to get tikz code reduction to work
344357
linehandle.set_data_3d(xData, yData, zData)
345358
linehandle.set_data(xData, yData)
@@ -348,6 +361,26 @@ def _remove_data(linehandle, id_remove):
348361
linehandle.set_ydata(yData)
349362

350363

364+
def _split_data_2D(data):
365+
xData, yData = np.split(data, 2, axis=1)
366+
return xData.reshape((-1,)), yData.reshape((-1,))
367+
368+
369+
def _stack_data_2D(xData, yData):
370+
data = np.stack([xData, yData], axis=1)
371+
return data
372+
373+
374+
def _split_data_3D(data):
375+
xData, yData, zData = np.split(data, 3, axis=1)
376+
return yData.reshape((-1,)), yData.reshape((-1,)), zData.reshape((-1,))
377+
378+
379+
def _stack_data_3D(xData, yData, zData):
380+
data = np.stack([xData, yData, zData], axis=1)
381+
return data
382+
383+
351384
def _diff(x, *args, **kwargs):
352385
"""modification of np.diff(x, *args, **kwargs).
353386
- If x is empty, return np.array([False])
@@ -364,20 +397,11 @@ def _diff(x, *args, **kwargs):
364397
return np.diff(x, *args, **kwargs)
365398

366399

367-
def _remove_NaNs(linehandle):
400+
def _remove_NaNs(data):
368401
"""Removes superflous NaNs in the data, i.e. those at the end/beginning of the data and consecutive ones.
369402
370403
:param linehandle: matplotlib linehandle object
371404
"""
372-
is3D = _lineIs3D(linehandle)
373-
if is3D:
374-
xData, yData, zData = linehandle.get_data_3d()
375-
data = np.stack([xData, yData, zData], axis=1)
376-
else:
377-
xData = linehandle.get_xdata().astype(np.float32)
378-
yData = linehandle.get_ydata().astype(np.float32)
379-
data = np.stack([xData, yData], axis=1)
380-
381405
id_nan = np.any(np.isnan(data), axis=1)
382406
id_remove = np.argwhere(id_nan).reshape((-1,))
383407
if _isempty(id_remove):
@@ -394,20 +418,13 @@ def _remove_NaNs(linehandle):
394418

395419
if _isempty(id_first):
396420
# remove entire data
397-
id_remove = np.arange(len(xData))
421+
id_remove = np.arange(len(data))
398422
else:
399423
id_remove = np.concatenate(
400-
[np.arange(0, id_first), id_remove, np.arange(id_last + 1, len(xData))]
424+
[np.arange(0, id_first), id_remove, np.arange(id_last + 1, len(data))]
401425
)
402426
data = np.delete(data, id_remove, axis=0)
403-
404-
if is3D:
405-
# TODO: I don't understand why I need to set both to get tikz code reduction to work
406-
linehandle.set_data_3d(data[:, 0], data[:, 1], data[:, 2])
407-
linehandle.set_data(xData, yData)
408-
else:
409-
linehandle.set_xdata(data[:, 0])
410-
linehandle.set_ydata(data[:, 1])
427+
return data
411428

412429

413430
def _isInBox(data, xLim, yLim):
@@ -443,6 +460,18 @@ def _axIs3D(axhandle):
443460
return hasattr(axhandle, "get_zlim")
444461

445462

463+
def _get_data(linehandle):
464+
is3D = _lineIs3D(linehandle)
465+
if is3D:
466+
xData, yData, zData = linehandle.get_data_3d()
467+
data = _stack_data_3D(xData, yData, zData)
468+
else:
469+
xData = linehandle.get_xdata().astype(np.float32)
470+
yData = linehandle.get_ydata().astype(np.float32)
471+
data = _stack_data_2D(xData, yData)
472+
return data
473+
474+
446475
def _get_visual_data(axhandle, linehandle):
447476
"""Returns the visual representation of the data,
448477
respecting possible log_scaling and projection into the image plane.
@@ -502,7 +531,14 @@ def _isempty(array):
502531
return _elements(array) == 0
503532

504533

505-
def _prune_outside_box(fighandle, axhandle, linehandle):
534+
def _line_has_lines(linehandle):
535+
hasLines = (linehandle.get_linestyle() is not None) and (
536+
linehandle.get_linewidth() > 0.0
537+
)
538+
return hasLines
539+
540+
541+
def _prune_outside_box(xLim, yLim, data, visual_data, is3D, hasLines):
506542
"""Some sections of the line may sit outside of the visible box. Cut those off.
507543
508544
This method is not pure because it updates the linehandle object's data.
@@ -514,28 +550,28 @@ def _prune_outside_box(fighandle, axhandle, linehandle):
514550
:param linehandle: matplotlib line2D handle object
515551
:type linehandle: obj
516552
"""
517-
xData, yData = _get_visual_data(axhandle, linehandle)
553+
# xData, yData = _get_visual_data(axhandle, linehandle)
518554

519-
data = np.stack([xData, yData], axis=1)
555+
# data = np.stack([xData, yData], axis=1)
520556

521-
if _elements(data) == 0:
522-
return
557+
if _elements(visual_data) == 0:
558+
return data
523559

524-
hasLines = (linehandle.get_linestyle() is not None) and (
525-
linehandle.get_linewidth() > 0.0
526-
)
560+
# hasLines = (linehandle.get_linestyle() is not None) and (
561+
# linehandle.get_linewidth() > 0.0
562+
# )
527563

528-
xLim, yLim = _get_visual_limits(fighandle, axhandle)
564+
# xLim, yLim = _get_visual_limits(fighandle, axhandle)
529565

530566
tol = 1.0e-10
531567
relaxedXLim = xLim + np.array([-tol, tol])
532568
relaxedYLim = yLim + np.array([-tol, tol])
533569

534-
dataIsInBox = _isInBox(data, relaxedXLim, relaxedYLim)
570+
dataIsInBox = _isInBox(visual_data, relaxedXLim, relaxedYLim)
535571

536572
shouldPlot = dataIsInBox
537573
if hasLines:
538-
segvis = _segment_visible(data, dataIsInBox, xLim, yLim)
574+
segvis = _segment_visible(visual_data, dataIsInBox, xLim, yLim)
539575
shouldPlot = np.logical_or(
540576
shouldPlot, np.concatenate([np.array([False]).reshape((-1,)), segvis])
541577
)
@@ -558,9 +594,11 @@ def _prune_outside_box(fighandle, axhandle, linehandle):
558594

559595
id_replace = id_remove[idx]
560596
id_remove = id_remove[np.logical_not(idx)]
561-
_replace_data_with_NaN(linehandle, id_replace)
562-
_remove_data(linehandle, id_remove)
563-
_remove_NaNs(linehandle)
597+
598+
data = _replace_data_with_NaN(data, id_replace, is3D)
599+
data = _remove_data(data, id_remove, is3D)
600+
data = _remove_NaNs(data)
601+
return data
564602

565603

566604
def _move_points_closer(fighandle, axhandle, linehandle):

0 commit comments

Comments
 (0)