|
2 | 2 |
|
3 | 3 | config = Config() |
4 | 4 |
|
5 | | -def plot_curve(*curves, savename=None, saveextension=None): |
| 5 | +def plot_curve(*curves, vertex_markings=True, savename=None, saveextension=None, return_fig=False, legend=True): |
6 | 6 | import matplotlib.pyplot as plt |
| 7 | + from mpl_toolkits.mplot3d import Axes3D |
7 | 8 | max_compl = 1 |
| 9 | + max_dim = 1 |
| 10 | + fig = plt.figure() |
| 11 | + ax = None |
8 | 12 | for curve in curves: |
9 | 13 | if isinstance(curve, backend.Curve): |
10 | 14 | max_compl = max(max_compl, curve.complexity) |
| 15 | + max_dim = max(max_dim, curve.dimensions) |
11 | 16 | elif isinstance(curve, backend.Curves): |
12 | 17 | for curv in curve: |
13 | 18 | max_compl = max(max_compl, curv.complexity) |
| 19 | + max_dim = max(max_dim, curv.dimensions) |
14 | 20 | elif isinstance(curve, backend.Clustering_Result): |
15 | 21 | for curv in curve: |
16 | 22 | max_compl = max(max_compl, curv.complexity) |
| 23 | + max_dim = max(max_dim, curv.dimensions) |
| 24 | + if max_dim >= 3: |
| 25 | + ax = fig.gca(projection='3d') |
| 26 | + else: |
| 27 | + ax = fig.gca() |
17 | 28 | for curve in curves: |
18 | 29 | if isinstance(curve, backend.Curve): |
19 | | - if curve.dimensions >= 2: |
20 | | - p = plt.plot(curve.values[:, 0], curve.values[:, 1], '--o', label = curve.name, markersize = 7, markevery = curve.complexity) |
21 | | - plt.plot(curve.values[1:, 0], curve.values[1:, 1], 'x', label = None, color = p[0].get_color(), markersize = 7) |
| 30 | + if curve.dimensions >= 3: |
| 31 | + p = ax.plot(curve.values[:, 0], curve.values[:, 1], curve.values[:, 2], linestyle='--', marker='o', label = curve.name, markersize = 7, markevery = curve.complexity) |
| 32 | + if vertex_markings: |
| 33 | + ax.plot(curve.values[1:, 0], curve.values[1:, 1], curve.values[:, 2], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
| 34 | + elif curve.dimensions == 2: |
| 35 | + p = ax.plot(curve.values[:, 0], curve.values[:, 1], linestyle='--', marker='o', label = curve.name, markersize = 7, markevery = curve.complexity) |
| 36 | + if vertex_markings: |
| 37 | + ax.plot(curve.values[1:, 0], curve.values[1:, 1], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
22 | 38 | else: |
23 | | - p = plt.plot([i * max_compl / len(curve) for i in range(len(curve))], curve.values, '--o', label = curve.name, markersize = 7, markevery = curve.complexity) |
24 | | - plt.plot([i * max_compl / len(curve) for i in range(1, len(curve))], curve.values[1:], 'x', label = None, color = p[0].get_color(), markersize = 7) |
| 39 | + p = ax.plot([i * max_compl / len(curve) for i in range(len(curve))], curve.values, linestyle='--', marker='o', label = curve.name, markersize = 7, markevery = curve.complexity) |
| 40 | + if vertex_markings: |
| 41 | + ax.plot([i * max_compl / len(curve) for i in range(1, len(curve))], curve.values[1:], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
25 | 42 | elif isinstance(curve, backend.Curves): |
26 | 43 | for curv in curve: |
27 | | - if curv.dimensions >= 2: |
28 | | - p = plt.plot(curv.values[:, 0], curv.values[:, 1], '--o', label = curv.name, markersize = 7, markevery = curv.complexity) |
29 | | - plt.plot(curv.values[1:, 0], curv.values[1:, 1], 'x', label = None, color = p[0].get_color(), markersize = 7) |
| 44 | + if curv.dimensions >= 3: |
| 45 | + p = ax.plot(curv.values[:, 0], curv.values[:, 1], curv.values[:, 2], linestyle='--', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity) |
| 46 | + if vertex_markings: |
| 47 | + ax.plot(curv.values[1:, 0], curv.values[1:, 1], curv.values[:, 2], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
| 48 | + elif curv.dimensions == 2: |
| 49 | + p = plt.plot(curv.values[:, 0], curv.values[:, 1], linestyle='--', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity) |
| 50 | + if vertex_markings: |
| 51 | + plt.plot(curv.values[1:, 0], curv.values[1:, 1], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
30 | 52 | else: |
31 | | - p = plt.plot([i * max_compl / len(curv) for i in range(len(curv))], curv.values, '--o', label = curv.name, markersize = 7, markevery = curv.complexity) |
32 | | - plt.plot([i * max_compl / len(curv) for i in range(1, len(curv))], curv.values[1:], 'x', label = None, color = p[0].get_color(), markersize = 7) |
| 53 | + p = plt.plot([i * max_compl / len(curv) for i in range(len(curv))], curv.values, linestyle='--', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity) |
| 54 | + if vertex_markings: |
| 55 | + plt.plot([i * max_compl / len(curv) for i in range(1, len(curv))], curv.values[1:], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
33 | 56 | elif isinstance(curve, backend.Clustering_Result): |
34 | 57 | for curv in curve: |
35 | | - if curv.dimensions >= 2: |
36 | | - p = plt.plot(curv.values[:, 0], curv.values[:, 1], '-o', label = curv.name, markersize = 7, markevery = curv.complexity) |
37 | | - plt.plot(curv.values[1:, 0], curv.values[1:, 1], 'x', label = None, color = p[0].get_color(), markersize = 7) |
| 58 | + if curv.dimensions >= 3: |
| 59 | + p = ax.plot(curv.values[:, 0], curv.values[:, 1], curv.values[:, 2], linestyle='-', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity) |
| 60 | + if vertex_markings: |
| 61 | + ax.plot(curv.values[1:, 0], curv.values[1:, 1], curv.values[:, 2], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
| 62 | + elif curv.dimensions == 2: |
| 63 | + p = plt.plot(curv.values[:, 0], curv.values[:, 1], linestyle='-', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity) |
| 64 | + if vertex_markings: |
| 65 | + plt.plot(curv.values[1:, 0], curv.values[1:, 1], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
38 | 66 | else: |
39 | | - p = plt.plot([i * max_compl / len(curv) for i in range(len(curv))], curv.values, '-o', label = curv.name, markersize = 7, markevery = curv.complexity) |
40 | | - plt.plot([i * max_compl / len(curv) for i in range(1, len(curv))], curv.values[1:], 'x', label = None, color = p[0].get_color(), markersize = 7) |
41 | | - plt.legend(title='Curve names:') |
42 | | - plt.title('Fred Curves') |
43 | | - if savename is None: |
| 67 | + p = plt.plot([i * max_compl / len(curv) for i in range(len(curv))], curv.values, linestyle='-', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity) |
| 68 | + if vertex_markings: |
| 69 | + plt.plot([i * max_compl / len(curv) for i in range(1, len(curv))], curv.values[1:], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
| 70 | + if legend: |
| 71 | + ax.legend(title='Curve names:') |
| 72 | + ax.set_title('Fred Curves') |
| 73 | + if savename is not None: |
| 74 | + plt.savefig("{}.{}".format(savename, saveextension), dpi=150) |
| 75 | + plt.close() |
| 76 | + elif return_fig: |
| 77 | + return fig |
| 78 | + else: |
44 | 79 | plt.show() |
| 80 | + plt.close() |
| 81 | + |
| 82 | +def plot_clustering(clustering_result, curves, vertex_markings=True, savename=None, saveextension=None, return_fig=False, legend=True): |
| 83 | + if not (isinstance(clustering_result, backend.Clustering_Result) and isinstance(curves, backend.Curves)): |
| 84 | + print("Check parameters!") |
| 85 | + return |
| 86 | + if len(clustering_result.assignment) < 1: |
| 87 | + print("compute_assignment was not called! calling now") |
| 88 | + clustering_result.compute_assignment(curves) |
| 89 | + from mpl_toolkits.mplot3d import Axes3D |
| 90 | + import matplotlib.pyplot as plt |
| 91 | + import matplotlib.colors as mcolors |
| 92 | + colors = list(mcolors.BASE_COLORS) |
| 93 | + if len(clustering_result) > len(colors): |
| 94 | + colors = list(mcolors.TABLEAU_COLORS) |
| 95 | + if len(clustering_result) > len(colors): |
| 96 | + colors = list(mcolors.mcolors.CSS4_COLORS) |
| 97 | + max_compl = 1 |
| 98 | + max_dim = 1 |
| 99 | + fig = plt.figure() |
| 100 | + ax = None |
| 101 | + for curve in curves: |
| 102 | + max_compl = max(max_compl, curve.complexity) |
| 103 | + max_dim = max(max_dim, curve.dimensions) |
| 104 | + if max_dim >= 3: |
| 105 | + ax = fig.gca(projection='3d') |
45 | 106 | else: |
| 107 | + ax = fig.gca() |
| 108 | + for i, curve in enumerate(clustering_result): |
| 109 | + if curve.dimensions >= 3: |
| 110 | + p = ax.plot(curve.values[:, 0], curve.values[:, 1], curve.values[:, 2], linestyle='-', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity) |
| 111 | + if vertex_markings: |
| 112 | + ax.plot(curve.values[1:, 0], curve.values[1:, 1], curve.values[:, 2], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
| 113 | + elif curve.dimensions == 2: |
| 114 | + p = ax.plot(curve.values[:, 0], curve.values[:, 1], linestyle='-', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity) |
| 115 | + if vertex_markings: |
| 116 | + ax.plot(curve.values[1:, 0], curve.values[1:, 1], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
| 117 | + else: |
| 118 | + p = ax.plot([i * max_compl / len(curve) for i in range(len(curve))], curve.values, linestyle='-', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity) |
| 119 | + if vertex_markings: |
| 120 | + ax.plot([i * max_compl / len(curve) for i in range(1, len(curve))], curve.values[1:], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
| 121 | + for i in range(len(clustering_result.assignment)): |
| 122 | + for j in range(clustering_result.assignment.count(i)): |
| 123 | + curve = curves[clustering_result.assignment.get(i,j)] |
| 124 | + if curve.dimensions >= 3: |
| 125 | + p = ax.plot(curve.values[:, 0], curve.values[:, 1], curve.values[:, 2], linestyle=':', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity) |
| 126 | + if vertex_markings: |
| 127 | + ax.plot(curve.values[1:, 0], curve.values[1:, 1], curve.values[:, 2], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
| 128 | + elif curve.dimensions == 2: |
| 129 | + p = ax.plot(curve.values[:, 0], curve.values[:, 1], linestyle=':', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity) |
| 130 | + if vertex_markings: |
| 131 | + ax.plot(curve.values[1:, 0], curve.values[1:, 1], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
| 132 | + else: |
| 133 | + p = ax.plot([i * max_compl / len(curve) for i in range(len(curve))], curve.values, linestyle=':', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity) |
| 134 | + if vertex_markings: |
| 135 | + ax.plot([i * max_compl / len(curve) for i in range(1, len(curve))], curve.values[1:], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7) |
| 136 | + if legend: |
| 137 | + ax.legend(title='Curve names:') |
| 138 | + ax.set_title('Fred Clustering') |
| 139 | + if savename is not None: |
46 | 140 | plt.savefig("{}.{}".format(savename, saveextension), dpi=150) |
47 | | - plt.close() |
| 141 | + plt.close() |
| 142 | + elif return_fig: |
| 143 | + return fig |
| 144 | + else: |
| 145 | + plt.show() |
| 146 | + plt.close() |
0 commit comments