Skip to content

Commit adf53e1

Browse files
docstring updates
1 parent 60aa008 commit adf53e1

File tree

2 files changed

+57
-61
lines changed

2 files changed

+57
-61
lines changed

test/test_cleanfigure.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -77,23 +77,6 @@ def test_pruneOutsideBox3D(self):
7777
plt.close("all")
7878

7979

80-
@pytest.mark.parametrize(
81-
"function, result", [("plot", False), ("step", True)],
82-
)
83-
def test_is_step(function, result):
84-
x = np.linspace(1, 100, 20)
85-
y = np.linspace(1, 100, 20)
86-
87-
with plt.rc_context(rc=RC_PARAMS):
88-
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
89-
if function == "plot":
90-
(l,) = ax.plot(x, y)
91-
elif function == "step":
92-
(l,) = ax.step(x, y)
93-
assert cleanfigure._isStep(l) == result
94-
plt.close("all")
95-
96-
9780
class Test_plottypes:
9881
"""Testing plot types found here https://matplotlib.org/3.1.1/tutorials/introductory/sample_plots.html"""
9982

tikzplotlib/_cleanfigure.py

Lines changed: 57 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,6 @@
44
import mpl_toolkits
55
from mpl_toolkits import mplot3d
66

7-
# TODO: increase coverage or remove unused functions [!!!]
8-
# TODO: see which test cases the matlab2tikz guys used [!!!]
9-
# TODO: refactoring: consistently use camel_case or _
10-
# TODO: find suitable test cases for remaining functions. [!!]
11-
# TODO: implement remaining functions [!!]
12-
# - simplify stair : plt.step
13-
# -- looks like matlabs stairs plot and matplotlibs plt.step is implemented differently. The data representation is different.
14-
# - there is still a missing code block in movePointsCloser. Maybe find suitable axes limits to get this code block to work
15-
# TODO: make grid of plot types which are working and which not. 2D and 3D
167

178
STEP_DRAW_STYLES = ["steps-pre", "steps-post", "steps-mid"]
189

@@ -207,7 +198,7 @@ def _clean_containers(axes):
207198

208199

209200
def _cleanline(fighandle, axhandle, linehandle, target_resolution, scale_precision):
210-
"""Clean a 2D Line plot figure.
201+
"""Clean a 2D or 3D Line plot figure.
211202
212203
:param fighandle: matplotlib figure object
213204
:param axhandle: matplotlib axes object
@@ -262,14 +253,24 @@ def _cleanline(fighandle, axhandle, linehandle, target_resolution, scale_precisi
262253
def _clean_collections(
263254
fighandle, axhandle, collection, target_resolution, scale_precision
264255
):
265-
import warnings
256+
"""Clean a 2D or 3D collection, i.e. scatter plot.
266257
267-
warnings.warn("Cleaning Path Collections (scatter plot) is not supported yet.")
258+
:param fighandle: matplotlib figure object
259+
:param axhandle: matplotlib axes object
260+
:param linehandle: mplot3d.art3d.Path3DCollection or mpl.collections.PathCollection
261+
:param target_resolution: target resolution of final figure in PPI.
262+
If a scalar integer is provided, it is assumed to be square in both axis.
263+
If a list or an np.array is provided, it is interpreted as [H, W].
264+
By default 600
265+
:type target_resolution: int, list or np.array, optional
266+
:param scalePrecision: scalar value indicating precision when scaling down.
267+
By default 1
268+
:type scalePrecision: float, optional
269+
"""
268270
data, is3D = _get_collection_data(collection)
269271
xLim, yLim = _get_visual_limits(fighandle, axhandle)
270272
visual_data = _get_visual_data(axhandle, data, is3D)
271273

272-
# TODO: not sure if it should be true or false.
273274
hasLines = True
274275

275276
data = _prune_outside_box(xLim, yLim, data, visual_data, is3D, hasLines)
@@ -317,6 +318,8 @@ def _get_visual_limits(fighandle, axhandle):
317318
:type fighandle: mpl.figure.Figure
318319
:param axhandle: handle to matplotlib axes object
319320
:type axhandle: mpl.axes.Axes or mpl_toolkits.mplot3d.axes3d.Axes3D
321+
322+
:returns: (xLim, yLim)
320323
"""
321324
is3D = _axIs3D(axhandle)
322325

@@ -364,18 +367,12 @@ def _replace_data_with_NaN(data, id_replace, is3D):
364367
:type id_replace: np.array
365368
:param linehandle: matplotlib line handle object
366369
:type linehandle: object
370+
371+
:returns: new_data
367372
"""
368373
if _isempty(id_replace):
369374
return data
370375

371-
# is3D = _lineIs3D(linehandle)
372-
373-
# if is3D:
374-
# xData, yData, zData = linehandle.get_data_3d()
375-
# zData = zData.copy()
376-
# else:
377-
# xData = linehandle.get_xdata().astype(np.float32)
378-
# yData = linehandle.get_ydata().astype(np.float32)
379376
if is3D:
380377
xData, yData, zData = _split_data_3D(data)
381378
else:
@@ -387,14 +384,6 @@ def _replace_data_with_NaN(data, id_replace, is3D):
387384
zData = zData.copy()
388385
zData[id_replace] = np.NaN
389386

390-
# if is3D:
391-
# # TODO: I don't understand why I need to set both to get tikz code reduction to work
392-
# linehandle.set_data_3d(xData, yData, zData)
393-
# linehandle.set_data(xData, yData)
394-
# else:
395-
# linehandle.set_xdata(xData)
396-
# linehandle.set_ydata(yData)
397-
398387
if is3D:
399388
new_data = _stack_data_3D(xData, yData, zData)
400389
else:
@@ -411,6 +400,8 @@ def _remove_data(data, id_remove, is3D):
411400
:type id_remove: np.array
412401
:param linehandle: matplotlib linehandle object
413402
:type linehandle: object
403+
404+
:returns: new_data
414405
"""
415406
if _isempty(id_remove):
416407
return data
@@ -446,21 +437,25 @@ def _update_line_data(linehandle, data):
446437

447438

448439
def _split_data_2D(data):
440+
""" data --> xData, yData """
449441
xData, yData = np.split(data, 2, axis=1)
450442
return xData.reshape((-1,)), yData.reshape((-1,))
451443

452444

453445
def _stack_data_2D(xData, yData):
446+
""" xData, yData --> data """
454447
data = np.stack([xData, yData], axis=1)
455448
return data
456449

457450

458451
def _split_data_3D(data):
452+
""" data --> xData, yData, zData """
459453
xData, yData, zData = np.split(data, 3, axis=1)
460454
return xData.reshape((-1,)), yData.reshape((-1,)), zData.reshape((-1,))
461455

462456

463457
def _stack_data_3D(xData, yData, zData):
458+
""" xData, yData, zData --> data """
464459
data = np.stack([xData, yData, zData], axis=1)
465460
return data
466461

@@ -485,6 +480,8 @@ def _remove_NaNs(data):
485480
"""Removes superflous NaNs in the data, i.e. those at the end/beginning of the data and consecutive ones.
486481
487482
:param linehandle: matplotlib linehandle object
483+
484+
:returns: data without NaNs
488485
"""
489486
id_nan = np.any(np.isnan(data), axis=1)
490487
id_remove = np.argwhere(id_nan).reshape((-1,))
@@ -520,6 +517,8 @@ def _isInBox(data, xLim, yLim):
520517
:type xLim: list or np.array
521518
:param yLim: y axes limits
522519
:type xLim: list or np.array
520+
521+
:returns: mask
523522
"""
524523
maskX = np.logical_and(data[:, 0] > xLim[0], data[:, 0] < xLim[1])
525524
maskY = np.logical_and(data[:, 1] > yLim[0], data[:, 1] < yLim[1])
@@ -545,6 +544,12 @@ def _axIs3D(axhandle):
545544

546545

547546
def _get_line_data(linehandle):
547+
"""Retrieve 2D or 3D data from line object.
548+
549+
:param linehandle: matplotlib linehandle object
550+
551+
:returns : (data, is3D)
552+
"""
548553
is3D = _lineIs3D(linehandle)
549554
if is3D:
550555
xData, yData, zData = linehandle.get_data_3d()
@@ -581,6 +586,8 @@ def _get_visual_data(axhandle, data, is3D):
581586
:type axhandle: object
582587
:param linehandle: handle for matplotlib line2D object
583588
:type linehandle: object
589+
590+
:returns : visualData
584591
"""
585592
if is3D:
586593
xData, yData, zData = _split_data_3D(data)
@@ -632,6 +639,7 @@ def _isempty(array):
632639

633640

634641
def _line_has_lines(linehandle):
642+
""" check if linestyle is not None and linewidth is larger than 0 """
635643
hasLines = (linehandle.get_linestyle() is not None) and (
636644
linehandle.get_linewidth() > 0.0
637645
)
@@ -649,11 +657,9 @@ def _prune_outside_box(xLim, yLim, data, visual_data, is3D, hasLines):
649657
:type axhandle: obj
650658
:param linehandle: matplotlib line2D handle object
651659
:type linehandle: obj
652-
"""
653-
# xData, yData = _get_visual_data(axhandle, linehandle)
654-
655-
# data = np.stack([xData, yData], axis=1)
656660
661+
:returns: data.
662+
"""
657663
if _elements(visual_data) == 0:
658664
return data
659665

@@ -711,6 +717,8 @@ def _move_points_closer(xLim, yLim, data):
711717
:type axhandle: obj
712718
:param linehandle: matplotlib line handle object
713719
:type linehandle: obj
720+
721+
:returns: data.
714722
"""
715723
# Calculate the extension of the extended box
716724
xWidth = xLim[1] - xLim[0]
@@ -749,6 +757,8 @@ def _insert_data(data, id_insert, dataInsert):
749757
:type id_insert: np.ndarray
750758
:param dataInsert: array of data to insert. Shape [N, 2]
751759
:type dataInsert: np.ndarray
760+
761+
:returns: data.
752762
"""
753763
if _isempty(id_insert):
754764
return data
@@ -788,6 +798,8 @@ def _simplify_line(
788798
If a scalar integer is provided, it is assumed to be square in both axis.
789799
If a list or an np.array is provided, it is interpreted as [H, W]
790800
:type target_resolution: int, list of int or np.array
801+
802+
:returns: data.
791803
"""
792804
if type(target_resolution) not in [list, np.ndarray, np.array]:
793805
if np.isinf(target_resolution) or target_resolution == 0:
@@ -800,8 +812,6 @@ def _simplify_line(
800812
if np.size(xDataVis) <= 2 or np.size(yDataVis) <= 2:
801813
return data
802814

803-
# xLim, yLim = _get_visual_limits(fighandle, axhandle)
804-
805815
# Automatically guess a tol based on the area of the figure and
806816
# the area and resolution of the output
807817
xRange = xLim[1] - xLim[0]
@@ -878,6 +888,7 @@ def _pixelate(x, y, xToPix, yToPix):
878888
:param yToPix: scalar converting x measure to pixel measure in y direction
879889
:type yToPix: float
880890
891+
:returns: mask
881892
"""
882893
mult = 2
883894
dataPixel = np.round(np.stack([x * xToPix * mult, y * yToPix * mult], axis=1))
@@ -907,6 +918,7 @@ def _get_width_height_in_pixels(fighandle, target_resolution):
907918
:param target_resolution: Target resolution in PPI/ DPI. If target_resolution is a scalar, calculate final pixels based on figure width and height.
908919
:type target_resolution: scalar or list or np.array
909920
921+
:returns : H, W
910922
"""
911923
if np.isscalar(target_resolution):
912924
# in matplotlib, the figsize units are always in inches
@@ -1016,17 +1028,12 @@ def _limit_precision(axhandle, data, is3D, alpha):
10161028
:type linehandle: obj
10171029
:param alpha: scalar value indicating precision when scaling down. By default 1
10181030
:type alpha: float
1031+
1032+
:returns : data
10191033
"""
10201034
if alpha <= 0:
10211035
return data
10221036

1023-
# is3D = _lineIs3D(linehandle)
1024-
# if is3D:
1025-
# xData, yData, zData = linehandle.get_data_3d()
1026-
# else:
1027-
# xData = linehandle.get_xdata().astype(np.float32)
1028-
# yData = linehandle.get_ydata().astype(np.float32)
1029-
10301037
if is3D:
10311038
xData, yData, zData = _split_data_3D(data)
10321039
else:
@@ -1079,6 +1086,8 @@ def _segment_visible(data, dataIsInBox, xLim, yLim):
10791086
:type xLim: list, np.array
10801087
:param yLim: y axes limits
10811088
:type yLim: list, np.array
1089+
1090+
:returns : mask
10821091
"""
10831092
n = np.shape(data)[0]
10841093
mask = np.zeros((n - 1, 1)) == 1
@@ -1120,7 +1129,6 @@ def _corners2D(xLim, yLim):
11201129
:param yLim: y limits interval. Shape [2, ]
11211130
:type yLim: np.array
11221131
"""
1123-
11241132
bottomLeft = np.array([xLim[0], yLim[0]])
11251133
topLeft = np.array([xLim[0], yLim[1]])
11261134
bottomRight = np.array([xLim[1], yLim[0]])
@@ -1171,8 +1179,10 @@ def _get_projection_matrix(axhandle):
11711179
11721180
:param axhandle: matplotlib axes handle object
11731181
:type axhandle: obj
1182+
1183+
:returns: Projection matrix P
1184+
:rtype: np.array
11741185
"""
1175-
# TODO: write test
11761186
az = np.deg2rad(axhandle.azim)
11771187
el = np.deg2rad(axhandle.elev)
11781188
rotationZ = np.array(
@@ -1214,6 +1224,9 @@ def _segments_intersect(X1, X2, X3, X4):
12141224
:type X3: np.ndarray
12151225
:param X4: X4
12161226
:type X4: np.ndarray
1227+
1228+
:returns: mask
1229+
:rtype: boolean np.ndarray
12171230
"""
12181231
Lambda = _cross_lines(X1, X2, X3, X4)
12191232

0 commit comments

Comments
 (0)