4
4
import mpl_toolkits
5
5
from mpl_toolkits import mplot3d
6
6
7
- # TODO: increase coverage or remove unused functions [!!!]
8
- # TODO: see which test cases the matlab2tikz guys used [!!!]
9
- # TODO: refactoring: consistently use camel_case or _
10
- # TODO: find suitable test cases for remaining functions. [!!]
11
- # TODO: implement remaining functions [!!]
12
- # - simplify stair : plt.step
13
- # -- looks like matlabs stairs plot and matplotlibs plt.step is implemented differently. The data representation is different.
14
- # - there is still a missing code block in movePointsCloser. Maybe find suitable axes limits to get this code block to work
15
- # TODO: make grid of plot types which are working and which not. 2D and 3D
16
7
17
8
STEP_DRAW_STYLES = ["steps-pre" , "steps-post" , "steps-mid" ]
18
9
@@ -207,7 +198,7 @@ def _clean_containers(axes):
207
198
208
199
209
200
def _cleanline (fighandle , axhandle , linehandle , target_resolution , scale_precision ):
210
- """Clean a 2D Line plot figure.
201
+ """Clean a 2D or 3D Line plot figure.
211
202
212
203
:param fighandle: matplotlib figure object
213
204
:param axhandle: matplotlib axes object
@@ -262,14 +253,24 @@ def _cleanline(fighandle, axhandle, linehandle, target_resolution, scale_precisi
262
253
def _clean_collections (
263
254
fighandle , axhandle , collection , target_resolution , scale_precision
264
255
):
265
- import warnings
256
+ """Clean a 2D or 3D collection, i.e. scatter plot.
266
257
267
- warnings .warn ("Cleaning Path Collections (scatter plot) is not supported yet." )
258
+ :param fighandle: matplotlib figure object
259
+ :param axhandle: matplotlib axes object
260
+ :param linehandle: mplot3d.art3d.Path3DCollection or mpl.collections.PathCollection
261
+ :param target_resolution: target resolution of final figure in PPI.
262
+ If a scalar integer is provided, it is assumed to be square in both axis.
263
+ If a list or an np.array is provided, it is interpreted as [H, W].
264
+ By default 600
265
+ :type target_resolution: int, list or np.array, optional
266
+ :param scalePrecision: scalar value indicating precision when scaling down.
267
+ By default 1
268
+ :type scalePrecision: float, optional
269
+ """
268
270
data , is3D = _get_collection_data (collection )
269
271
xLim , yLim = _get_visual_limits (fighandle , axhandle )
270
272
visual_data = _get_visual_data (axhandle , data , is3D )
271
273
272
- # TODO: not sure if it should be true or false.
273
274
hasLines = True
274
275
275
276
data = _prune_outside_box (xLim , yLim , data , visual_data , is3D , hasLines )
@@ -317,6 +318,8 @@ def _get_visual_limits(fighandle, axhandle):
317
318
:type fighandle: mpl.figure.Figure
318
319
:param axhandle: handle to matplotlib axes object
319
320
:type axhandle: mpl.axes.Axes or mpl_toolkits.mplot3d.axes3d.Axes3D
321
+
322
+ :returns: (xLim, yLim)
320
323
"""
321
324
is3D = _axIs3D (axhandle )
322
325
@@ -364,18 +367,12 @@ def _replace_data_with_NaN(data, id_replace, is3D):
364
367
:type id_replace: np.array
365
368
:param linehandle: matplotlib line handle object
366
369
:type linehandle: object
370
+
371
+ :returns: new_data
367
372
"""
368
373
if _isempty (id_replace ):
369
374
return data
370
375
371
- # is3D = _lineIs3D(linehandle)
372
-
373
- # if is3D:
374
- # xData, yData, zData = linehandle.get_data_3d()
375
- # zData = zData.copy()
376
- # else:
377
- # xData = linehandle.get_xdata().astype(np.float32)
378
- # yData = linehandle.get_ydata().astype(np.float32)
379
376
if is3D :
380
377
xData , yData , zData = _split_data_3D (data )
381
378
else :
@@ -387,14 +384,6 @@ def _replace_data_with_NaN(data, id_replace, is3D):
387
384
zData = zData .copy ()
388
385
zData [id_replace ] = np .NaN
389
386
390
- # if is3D:
391
- # # TODO: I don't understand why I need to set both to get tikz code reduction to work
392
- # linehandle.set_data_3d(xData, yData, zData)
393
- # linehandle.set_data(xData, yData)
394
- # else:
395
- # linehandle.set_xdata(xData)
396
- # linehandle.set_ydata(yData)
397
-
398
387
if is3D :
399
388
new_data = _stack_data_3D (xData , yData , zData )
400
389
else :
@@ -411,6 +400,8 @@ def _remove_data(data, id_remove, is3D):
411
400
:type id_remove: np.array
412
401
:param linehandle: matplotlib linehandle object
413
402
:type linehandle: object
403
+
404
+ :returns: new_data
414
405
"""
415
406
if _isempty (id_remove ):
416
407
return data
@@ -446,21 +437,25 @@ def _update_line_data(linehandle, data):
446
437
447
438
448
439
def _split_data_2D (data ):
440
+ """ data --> xData, yData """
449
441
xData , yData = np .split (data , 2 , axis = 1 )
450
442
return xData .reshape ((- 1 ,)), yData .reshape ((- 1 ,))
451
443
452
444
453
445
def _stack_data_2D (xData , yData ):
446
+ """ xData, yData --> data """
454
447
data = np .stack ([xData , yData ], axis = 1 )
455
448
return data
456
449
457
450
458
451
def _split_data_3D (data ):
452
+ """ data --> xData, yData, zData """
459
453
xData , yData , zData = np .split (data , 3 , axis = 1 )
460
454
return xData .reshape ((- 1 ,)), yData .reshape ((- 1 ,)), zData .reshape ((- 1 ,))
461
455
462
456
463
457
def _stack_data_3D (xData , yData , zData ):
458
+ """ xData, yData, zData --> data """
464
459
data = np .stack ([xData , yData , zData ], axis = 1 )
465
460
return data
466
461
@@ -485,6 +480,8 @@ def _remove_NaNs(data):
485
480
"""Removes superflous NaNs in the data, i.e. those at the end/beginning of the data and consecutive ones.
486
481
487
482
:param linehandle: matplotlib linehandle object
483
+
484
+ :returns: data without NaNs
488
485
"""
489
486
id_nan = np .any (np .isnan (data ), axis = 1 )
490
487
id_remove = np .argwhere (id_nan ).reshape ((- 1 ,))
@@ -520,6 +517,8 @@ def _isInBox(data, xLim, yLim):
520
517
:type xLim: list or np.array
521
518
:param yLim: y axes limits
522
519
:type xLim: list or np.array
520
+
521
+ :returns: mask
523
522
"""
524
523
maskX = np .logical_and (data [:, 0 ] > xLim [0 ], data [:, 0 ] < xLim [1 ])
525
524
maskY = np .logical_and (data [:, 1 ] > yLim [0 ], data [:, 1 ] < yLim [1 ])
@@ -545,6 +544,12 @@ def _axIs3D(axhandle):
545
544
546
545
547
546
def _get_line_data (linehandle ):
547
+ """Retrieve 2D or 3D data from line object.
548
+
549
+ :param linehandle: matplotlib linehandle object
550
+
551
+ :returns : (data, is3D)
552
+ """
548
553
is3D = _lineIs3D (linehandle )
549
554
if is3D :
550
555
xData , yData , zData = linehandle .get_data_3d ()
@@ -581,6 +586,8 @@ def _get_visual_data(axhandle, data, is3D):
581
586
:type axhandle: object
582
587
:param linehandle: handle for matplotlib line2D object
583
588
:type linehandle: object
589
+
590
+ :returns : visualData
584
591
"""
585
592
if is3D :
586
593
xData , yData , zData = _split_data_3D (data )
@@ -632,6 +639,7 @@ def _isempty(array):
632
639
633
640
634
641
def _line_has_lines (linehandle ):
642
+ """ check if linestyle is not None and linewidth is larger than 0 """
635
643
hasLines = (linehandle .get_linestyle () is not None ) and (
636
644
linehandle .get_linewidth () > 0.0
637
645
)
@@ -649,11 +657,9 @@ def _prune_outside_box(xLim, yLim, data, visual_data, is3D, hasLines):
649
657
:type axhandle: obj
650
658
:param linehandle: matplotlib line2D handle object
651
659
:type linehandle: obj
652
- """
653
- # xData, yData = _get_visual_data(axhandle, linehandle)
654
-
655
- # data = np.stack([xData, yData], axis=1)
656
660
661
+ :returns: data.
662
+ """
657
663
if _elements (visual_data ) == 0 :
658
664
return data
659
665
@@ -711,6 +717,8 @@ def _move_points_closer(xLim, yLim, data):
711
717
:type axhandle: obj
712
718
:param linehandle: matplotlib line handle object
713
719
:type linehandle: obj
720
+
721
+ :returns: data.
714
722
"""
715
723
# Calculate the extension of the extended box
716
724
xWidth = xLim [1 ] - xLim [0 ]
@@ -749,6 +757,8 @@ def _insert_data(data, id_insert, dataInsert):
749
757
:type id_insert: np.ndarray
750
758
:param dataInsert: array of data to insert. Shape [N, 2]
751
759
:type dataInsert: np.ndarray
760
+
761
+ :returns: data.
752
762
"""
753
763
if _isempty (id_insert ):
754
764
return data
@@ -788,6 +798,8 @@ def _simplify_line(
788
798
If a scalar integer is provided, it is assumed to be square in both axis.
789
799
If a list or an np.array is provided, it is interpreted as [H, W]
790
800
:type target_resolution: int, list of int or np.array
801
+
802
+ :returns: data.
791
803
"""
792
804
if type (target_resolution ) not in [list , np .ndarray , np .array ]:
793
805
if np .isinf (target_resolution ) or target_resolution == 0 :
@@ -800,8 +812,6 @@ def _simplify_line(
800
812
if np .size (xDataVis ) <= 2 or np .size (yDataVis ) <= 2 :
801
813
return data
802
814
803
- # xLim, yLim = _get_visual_limits(fighandle, axhandle)
804
-
805
815
# Automatically guess a tol based on the area of the figure and
806
816
# the area and resolution of the output
807
817
xRange = xLim [1 ] - xLim [0 ]
@@ -878,6 +888,7 @@ def _pixelate(x, y, xToPix, yToPix):
878
888
:param yToPix: scalar converting x measure to pixel measure in y direction
879
889
:type yToPix: float
880
890
891
+ :returns: mask
881
892
"""
882
893
mult = 2
883
894
dataPixel = np .round (np .stack ([x * xToPix * mult , y * yToPix * mult ], axis = 1 ))
@@ -907,6 +918,7 @@ def _get_width_height_in_pixels(fighandle, target_resolution):
907
918
:param target_resolution: Target resolution in PPI/ DPI. If target_resolution is a scalar, calculate final pixels based on figure width and height.
908
919
:type target_resolution: scalar or list or np.array
909
920
921
+ :returns : H, W
910
922
"""
911
923
if np .isscalar (target_resolution ):
912
924
# in matplotlib, the figsize units are always in inches
@@ -1016,17 +1028,12 @@ def _limit_precision(axhandle, data, is3D, alpha):
1016
1028
:type linehandle: obj
1017
1029
:param alpha: scalar value indicating precision when scaling down. By default 1
1018
1030
:type alpha: float
1031
+
1032
+ :returns : data
1019
1033
"""
1020
1034
if alpha <= 0 :
1021
1035
return data
1022
1036
1023
- # is3D = _lineIs3D(linehandle)
1024
- # if is3D:
1025
- # xData, yData, zData = linehandle.get_data_3d()
1026
- # else:
1027
- # xData = linehandle.get_xdata().astype(np.float32)
1028
- # yData = linehandle.get_ydata().astype(np.float32)
1029
-
1030
1037
if is3D :
1031
1038
xData , yData , zData = _split_data_3D (data )
1032
1039
else :
@@ -1079,6 +1086,8 @@ def _segment_visible(data, dataIsInBox, xLim, yLim):
1079
1086
:type xLim: list, np.array
1080
1087
:param yLim: y axes limits
1081
1088
:type yLim: list, np.array
1089
+
1090
+ :returns : mask
1082
1091
"""
1083
1092
n = np .shape (data )[0 ]
1084
1093
mask = np .zeros ((n - 1 , 1 )) == 1
@@ -1120,7 +1129,6 @@ def _corners2D(xLim, yLim):
1120
1129
:param yLim: y limits interval. Shape [2, ]
1121
1130
:type yLim: np.array
1122
1131
"""
1123
-
1124
1132
bottomLeft = np .array ([xLim [0 ], yLim [0 ]])
1125
1133
topLeft = np .array ([xLim [0 ], yLim [1 ]])
1126
1134
bottomRight = np .array ([xLim [1 ], yLim [0 ]])
@@ -1171,8 +1179,10 @@ def _get_projection_matrix(axhandle):
1171
1179
1172
1180
:param axhandle: matplotlib axes handle object
1173
1181
:type axhandle: obj
1182
+
1183
+ :returns: Projection matrix P
1184
+ :rtype: np.array
1174
1185
"""
1175
- # TODO: write test
1176
1186
az = np .deg2rad (axhandle .azim )
1177
1187
el = np .deg2rad (axhandle .elev )
1178
1188
rotationZ = np .array (
@@ -1214,6 +1224,9 @@ def _segments_intersect(X1, X2, X3, X4):
1214
1224
:type X3: np.ndarray
1215
1225
:param X4: X4
1216
1226
:type X4: np.ndarray
1227
+
1228
+ :returns: mask
1229
+ :rtype: boolean np.ndarray
1217
1230
"""
1218
1231
Lambda = _cross_lines (X1 , X2 , X3 , X4 )
1219
1232
0 commit comments