44import mpl_toolkits
55from mpl_toolkits import mplot3d
66
7- # TODO: increase coverage or remove unused functions [!!!]
8- # TODO: see which test cases the matlab2tikz guys used [!!!]
9- # TODO: refactoring: consistently use camel_case or _
10- # TODO: find suitable test cases for remaining functions. [!!]
11- # TODO: implement remaining functions [!!]
12- # - simplify stair : plt.step
13- # -- looks like matlabs stairs plot and matplotlibs plt.step is implemented differently. The data representation is different.
14- # - there is still a missing code block in movePointsCloser. Maybe find suitable axes limits to get this code block to work
15- # TODO: make grid of plot types which are working and which not. 2D and 3D
167
178STEP_DRAW_STYLES = ["steps-pre" , "steps-post" , "steps-mid" ]
189
@@ -207,7 +198,7 @@ def _clean_containers(axes):
207198
208199
209200def _cleanline (fighandle , axhandle , linehandle , target_resolution , scale_precision ):
210- """Clean a 2D Line plot figure.
201+ """Clean a 2D or 3D Line plot figure.
211202
212203 :param fighandle: matplotlib figure object
213204 :param axhandle: matplotlib axes object
@@ -262,14 +253,24 @@ def _cleanline(fighandle, axhandle, linehandle, target_resolution, scale_precisi
262253def _clean_collections (
263254 fighandle , axhandle , collection , target_resolution , scale_precision
264255):
265- import warnings
256+ """Clean a 2D or 3D collection, i.e. scatter plot.
266257
267- warnings .warn ("Cleaning Path Collections (scatter plot) is not supported yet." )
258+ :param fighandle: matplotlib figure object
259+ :param axhandle: matplotlib axes object
260+ :param linehandle: mplot3d.art3d.Path3DCollection or mpl.collections.PathCollection
261+ :param target_resolution: target resolution of final figure in PPI.
262+ If a scalar integer is provided, it is assumed to be square in both axis.
263+ If a list or an np.array is provided, it is interpreted as [H, W].
264+ By default 600
265+ :type target_resolution: int, list or np.array, optional
266+ :param scalePrecision: scalar value indicating precision when scaling down.
267+ By default 1
268+ :type scalePrecision: float, optional
269+ """
268270 data , is3D = _get_collection_data (collection )
269271 xLim , yLim = _get_visual_limits (fighandle , axhandle )
270272 visual_data = _get_visual_data (axhandle , data , is3D )
271273
272- # TODO: not sure if it should be true or false.
273274 hasLines = True
274275
275276 data = _prune_outside_box (xLim , yLim , data , visual_data , is3D , hasLines )
@@ -317,6 +318,8 @@ def _get_visual_limits(fighandle, axhandle):
317318 :type fighandle: mpl.figure.Figure
318319 :param axhandle: handle to matplotlib axes object
319320 :type axhandle: mpl.axes.Axes or mpl_toolkits.mplot3d.axes3d.Axes3D
321+
322+ :returns: (xLim, yLim)
320323 """
321324 is3D = _axIs3D (axhandle )
322325
@@ -364,18 +367,12 @@ def _replace_data_with_NaN(data, id_replace, is3D):
364367 :type id_replace: np.array
365368 :param linehandle: matplotlib line handle object
366369 :type linehandle: object
370+
371+ :returns: new_data
367372 """
368373 if _isempty (id_replace ):
369374 return data
370375
371- # is3D = _lineIs3D(linehandle)
372-
373- # if is3D:
374- # xData, yData, zData = linehandle.get_data_3d()
375- # zData = zData.copy()
376- # else:
377- # xData = linehandle.get_xdata().astype(np.float32)
378- # yData = linehandle.get_ydata().astype(np.float32)
379376 if is3D :
380377 xData , yData , zData = _split_data_3D (data )
381378 else :
@@ -387,14 +384,6 @@ def _replace_data_with_NaN(data, id_replace, is3D):
387384 zData = zData .copy ()
388385 zData [id_replace ] = np .NaN
389386
390- # if is3D:
391- # # TODO: I don't understand why I need to set both to get tikz code reduction to work
392- # linehandle.set_data_3d(xData, yData, zData)
393- # linehandle.set_data(xData, yData)
394- # else:
395- # linehandle.set_xdata(xData)
396- # linehandle.set_ydata(yData)
397-
398387 if is3D :
399388 new_data = _stack_data_3D (xData , yData , zData )
400389 else :
@@ -411,6 +400,8 @@ def _remove_data(data, id_remove, is3D):
411400 :type id_remove: np.array
412401 :param linehandle: matplotlib linehandle object
413402 :type linehandle: object
403+
404+ :returns: new_data
414405 """
415406 if _isempty (id_remove ):
416407 return data
@@ -446,21 +437,25 @@ def _update_line_data(linehandle, data):
446437
447438
448439def _split_data_2D (data ):
440+ """ data --> xData, yData """
449441 xData , yData = np .split (data , 2 , axis = 1 )
450442 return xData .reshape ((- 1 ,)), yData .reshape ((- 1 ,))
451443
452444
453445def _stack_data_2D (xData , yData ):
446+ """ xData, yData --> data """
454447 data = np .stack ([xData , yData ], axis = 1 )
455448 return data
456449
457450
458451def _split_data_3D (data ):
452+ """ data --> xData, yData, zData """
459453 xData , yData , zData = np .split (data , 3 , axis = 1 )
460454 return xData .reshape ((- 1 ,)), yData .reshape ((- 1 ,)), zData .reshape ((- 1 ,))
461455
462456
463457def _stack_data_3D (xData , yData , zData ):
458+ """ xData, yData, zData --> data """
464459 data = np .stack ([xData , yData , zData ], axis = 1 )
465460 return data
466461
@@ -485,6 +480,8 @@ def _remove_NaNs(data):
485480 """Removes superflous NaNs in the data, i.e. those at the end/beginning of the data and consecutive ones.
486481
487482 :param linehandle: matplotlib linehandle object
483+
484+ :returns: data without NaNs
488485 """
489486 id_nan = np .any (np .isnan (data ), axis = 1 )
490487 id_remove = np .argwhere (id_nan ).reshape ((- 1 ,))
@@ -520,6 +517,8 @@ def _isInBox(data, xLim, yLim):
520517 :type xLim: list or np.array
521518 :param yLim: y axes limits
522519 :type xLim: list or np.array
520+
521+ :returns: mask
523522 """
524523 maskX = np .logical_and (data [:, 0 ] > xLim [0 ], data [:, 0 ] < xLim [1 ])
525524 maskY = np .logical_and (data [:, 1 ] > yLim [0 ], data [:, 1 ] < yLim [1 ])
@@ -545,6 +544,12 @@ def _axIs3D(axhandle):
545544
546545
547546def _get_line_data (linehandle ):
547+ """Retrieve 2D or 3D data from line object.
548+
549+ :param linehandle: matplotlib linehandle object
550+
551+ :returns : (data, is3D)
552+ """
548553 is3D = _lineIs3D (linehandle )
549554 if is3D :
550555 xData , yData , zData = linehandle .get_data_3d ()
@@ -581,6 +586,8 @@ def _get_visual_data(axhandle, data, is3D):
581586 :type axhandle: object
582587 :param linehandle: handle for matplotlib line2D object
583588 :type linehandle: object
589+
590+ :returns : visualData
584591 """
585592 if is3D :
586593 xData , yData , zData = _split_data_3D (data )
@@ -632,6 +639,7 @@ def _isempty(array):
632639
633640
634641def _line_has_lines (linehandle ):
642+ """ check if linestyle is not None and linewidth is larger than 0 """
635643 hasLines = (linehandle .get_linestyle () is not None ) and (
636644 linehandle .get_linewidth () > 0.0
637645 )
@@ -649,11 +657,9 @@ def _prune_outside_box(xLim, yLim, data, visual_data, is3D, hasLines):
649657 :type axhandle: obj
650658 :param linehandle: matplotlib line2D handle object
651659 :type linehandle: obj
652- """
653- # xData, yData = _get_visual_data(axhandle, linehandle)
654-
655- # data = np.stack([xData, yData], axis=1)
656660
661+ :returns: data.
662+ """
657663 if _elements (visual_data ) == 0 :
658664 return data
659665
@@ -711,6 +717,8 @@ def _move_points_closer(xLim, yLim, data):
711717 :type axhandle: obj
712718 :param linehandle: matplotlib line handle object
713719 :type linehandle: obj
720+
721+ :returns: data.
714722 """
715723 # Calculate the extension of the extended box
716724 xWidth = xLim [1 ] - xLim [0 ]
@@ -749,6 +757,8 @@ def _insert_data(data, id_insert, dataInsert):
749757 :type id_insert: np.ndarray
750758 :param dataInsert: array of data to insert. Shape [N, 2]
751759 :type dataInsert: np.ndarray
760+
761+ :returns: data.
752762 """
753763 if _isempty (id_insert ):
754764 return data
@@ -788,6 +798,8 @@ def _simplify_line(
788798 If a scalar integer is provided, it is assumed to be square in both axis.
789799 If a list or an np.array is provided, it is interpreted as [H, W]
790800 :type target_resolution: int, list of int or np.array
801+
802+ :returns: data.
791803 """
792804 if type (target_resolution ) not in [list , np .ndarray , np .array ]:
793805 if np .isinf (target_resolution ) or target_resolution == 0 :
@@ -800,8 +812,6 @@ def _simplify_line(
800812 if np .size (xDataVis ) <= 2 or np .size (yDataVis ) <= 2 :
801813 return data
802814
803- # xLim, yLim = _get_visual_limits(fighandle, axhandle)
804-
805815 # Automatically guess a tol based on the area of the figure and
806816 # the area and resolution of the output
807817 xRange = xLim [1 ] - xLim [0 ]
@@ -878,6 +888,7 @@ def _pixelate(x, y, xToPix, yToPix):
878888 :param yToPix: scalar converting x measure to pixel measure in y direction
879889 :type yToPix: float
880890
891+ :returns: mask
881892 """
882893 mult = 2
883894 dataPixel = np .round (np .stack ([x * xToPix * mult , y * yToPix * mult ], axis = 1 ))
@@ -907,6 +918,7 @@ def _get_width_height_in_pixels(fighandle, target_resolution):
907918 :param target_resolution: Target resolution in PPI/ DPI. If target_resolution is a scalar, calculate final pixels based on figure width and height.
908919 :type target_resolution: scalar or list or np.array
909920
921+ :returns : H, W
910922 """
911923 if np .isscalar (target_resolution ):
912924 # in matplotlib, the figsize units are always in inches
@@ -1016,17 +1028,12 @@ def _limit_precision(axhandle, data, is3D, alpha):
10161028 :type linehandle: obj
10171029 :param alpha: scalar value indicating precision when scaling down. By default 1
10181030 :type alpha: float
1031+
1032+ :returns : data
10191033 """
10201034 if alpha <= 0 :
10211035 return data
10221036
1023- # is3D = _lineIs3D(linehandle)
1024- # if is3D:
1025- # xData, yData, zData = linehandle.get_data_3d()
1026- # else:
1027- # xData = linehandle.get_xdata().astype(np.float32)
1028- # yData = linehandle.get_ydata().astype(np.float32)
1029-
10301037 if is3D :
10311038 xData , yData , zData = _split_data_3D (data )
10321039 else :
@@ -1079,6 +1086,8 @@ def _segment_visible(data, dataIsInBox, xLim, yLim):
10791086 :type xLim: list, np.array
10801087 :param yLim: y axes limits
10811088 :type yLim: list, np.array
1089+
1090+ :returns : mask
10821091 """
10831092 n = np .shape (data )[0 ]
10841093 mask = np .zeros ((n - 1 , 1 )) == 1
@@ -1120,7 +1129,6 @@ def _corners2D(xLim, yLim):
11201129 :param yLim: y limits interval. Shape [2, ]
11211130 :type yLim: np.array
11221131 """
1123-
11241132 bottomLeft = np .array ([xLim [0 ], yLim [0 ]])
11251133 topLeft = np .array ([xLim [0 ], yLim [1 ]])
11261134 bottomRight = np .array ([xLim [1 ], yLim [0 ]])
@@ -1171,8 +1179,10 @@ def _get_projection_matrix(axhandle):
11711179
11721180 :param axhandle: matplotlib axes handle object
11731181 :type axhandle: obj
1182+
1183+ :returns: Projection matrix P
1184+ :rtype: np.array
11741185 """
1175- # TODO: write test
11761186 az = np .deg2rad (axhandle .azim )
11771187 el = np .deg2rad (axhandle .elev )
11781188 rotationZ = np .array (
@@ -1214,6 +1224,9 @@ def _segments_intersect(X1, X2, X3, X4):
12141224 :type X3: np.ndarray
12151225 :param X4: X4
12161226 :type X4: np.ndarray
1227+
1228+ :returns: mask
1229+ :rtype: boolean np.ndarray
12171230 """
12181231 Lambda = _cross_lines (X1 , X2 , X3 , X4 )
12191232
0 commit comments