Skip to content

Commit 402e045

Browse files
committed
plot functions
1 parent 252b450 commit 402e045

File tree

2 files changed

+120
-21
lines changed

2 files changed

+120
-21
lines changed

Fred/__init__.py

Lines changed: 119 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,46 +2,145 @@
22

33
config = Config()
44

5-
def plot_curve(*curves, savename=None, saveextension=None):
5+
def plot_curve(*curves, vertex_markings=True, savename=None, saveextension=None, return_fig=False, legend=True):
66
import matplotlib.pyplot as plt
7+
from mpl_toolkits.mplot3d import Axes3D
78
max_compl = 1
9+
max_dim = 1
10+
fig = plt.figure()
11+
ax = None
812
for curve in curves:
913
if isinstance(curve, backend.Curve):
1014
max_compl = max(max_compl, curve.complexity)
15+
max_dim = max(max_dim, curve.dimensions)
1116
elif isinstance(curve, backend.Curves):
1217
for curv in curve:
1318
max_compl = max(max_compl, curv.complexity)
19+
max_dim = max(max_dim, curv.dimensions)
1420
elif isinstance(curve, backend.Clustering_Result):
1521
for curv in curve:
1622
max_compl = max(max_compl, curv.complexity)
23+
max_dim = max(max_dim, curv.dimensions)
24+
if max_dim >= 3:
25+
ax = fig.gca(projection='3d')
26+
else:
27+
ax = fig.gca()
1728
for curve in curves:
1829
if isinstance(curve, backend.Curve):
19-
if curve.dimensions >= 2:
20-
p = plt.plot(curve.values[:, 0], curve.values[:, 1], '--o', label = curve.name, markersize = 7, markevery = curve.complexity)
21-
plt.plot(curve.values[1:, 0], curve.values[1:, 1], 'x', label = None, color = p[0].get_color(), markersize = 7)
30+
if curve.dimensions >= 3:
31+
p = ax.plot(curve.values[:, 0], curve.values[:, 1], curve.values[:, 2], linestyle='--', marker='o', label = curve.name, markersize = 7, markevery = curve.complexity)
32+
if vertex_markings:
33+
ax.plot(curve.values[1:, 0], curve.values[1:, 1], curve.values[:, 2], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
34+
elif curve.dimensions == 2:
35+
p = ax.plot(curve.values[:, 0], curve.values[:, 1], linestyle='--', marker='o', label = curve.name, markersize = 7, markevery = curve.complexity)
36+
if vertex_markings:
37+
ax.plot(curve.values[1:, 0], curve.values[1:, 1], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
2238
else:
23-
p = plt.plot([i * max_compl / len(curve) for i in range(len(curve))], curve.values, '--o', label = curve.name, markersize = 7, markevery = curve.complexity)
24-
plt.plot([i * max_compl / len(curve) for i in range(1, len(curve))], curve.values[1:], 'x', label = None, color = p[0].get_color(), markersize = 7)
39+
p = ax.plot([i * max_compl / len(curve) for i in range(len(curve))], curve.values, linestyle='--', marker='o', label = curve.name, markersize = 7, markevery = curve.complexity)
40+
if vertex_markings:
41+
ax.plot([i * max_compl / len(curve) for i in range(1, len(curve))], curve.values[1:], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
2542
elif isinstance(curve, backend.Curves):
2643
for curv in curve:
27-
if curv.dimensions >= 2:
28-
p = plt.plot(curv.values[:, 0], curv.values[:, 1], '--o', label = curv.name, markersize = 7, markevery = curv.complexity)
29-
plt.plot(curv.values[1:, 0], curv.values[1:, 1], 'x', label = None, color = p[0].get_color(), markersize = 7)
44+
if curv.dimensions >= 3:
45+
p = ax.plot(curv.values[:, 0], curv.values[:, 1], curv.values[:, 2], linestyle='--', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity)
46+
if vertex_markings:
47+
ax.plot(curv.values[1:, 0], curv.values[1:, 1], curv.values[:, 2], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
48+
elif curv.dimensions == 2:
49+
p = plt.plot(curv.values[:, 0], curv.values[:, 1], linestyle='--', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity)
50+
if vertex_markings:
51+
plt.plot(curv.values[1:, 0], curv.values[1:, 1], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
3052
else:
31-
p = plt.plot([i * max_compl / len(curv) for i in range(len(curv))], curv.values, '--o', label = curv.name, markersize = 7, markevery = curv.complexity)
32-
plt.plot([i * max_compl / len(curv) for i in range(1, len(curv))], curv.values[1:], 'x', label = None, color = p[0].get_color(), markersize = 7)
53+
p = plt.plot([i * max_compl / len(curv) for i in range(len(curv))], curv.values, linestyle='--', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity)
54+
if vertex_markings:
55+
plt.plot([i * max_compl / len(curv) for i in range(1, len(curv))], curv.values[1:], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
3356
elif isinstance(curve, backend.Clustering_Result):
3457
for curv in curve:
35-
if curv.dimensions >= 2:
36-
p = plt.plot(curv.values[:, 0], curv.values[:, 1], '-o', label = curv.name, markersize = 7, markevery = curv.complexity)
37-
plt.plot(curv.values[1:, 0], curv.values[1:, 1], 'x', label = None, color = p[0].get_color(), markersize = 7)
58+
if curv.dimensions >= 3:
59+
p = ax.plot(curv.values[:, 0], curv.values[:, 1], curv.values[:, 2], linestyle='-', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity)
60+
if vertex_markings:
61+
ax.plot(curv.values[1:, 0], curv.values[1:, 1], curv.values[:, 2], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
62+
elif curv.dimensions == 2:
63+
p = plt.plot(curv.values[:, 0], curv.values[:, 1], linestyle='-', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity)
64+
if vertex_markings:
65+
plt.plot(curv.values[1:, 0], curv.values[1:, 1], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
3866
else:
39-
p = plt.plot([i * max_compl / len(curv) for i in range(len(curv))], curv.values, '-o', label = curv.name, markersize = 7, markevery = curv.complexity)
40-
plt.plot([i * max_compl / len(curv) for i in range(1, len(curv))], curv.values[1:], 'x', label = None, color = p[0].get_color(), markersize = 7)
41-
plt.legend(title='Curve names:')
42-
plt.title('Fred Curves')
43-
if savename is None:
67+
p = plt.plot([i * max_compl / len(curv) for i in range(len(curv))], curv.values, linestyle='-', marker='o', label = curv.name, markersize = 7, markevery = curv.complexity)
68+
if vertex_markings:
69+
plt.plot([i * max_compl / len(curv) for i in range(1, len(curv))], curv.values[1:], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
70+
if legend:
71+
ax.legend(title='Curve names:')
72+
ax.set_title('Fred Curves')
73+
if savename is not None:
74+
plt.savefig("{}.{}".format(savename, saveextension), dpi=150)
75+
plt.close()
76+
elif return_fig:
77+
return fig
78+
else:
4479
plt.show()
80+
plt.close()
81+
82+
def plot_clustering(clustering_result, curves, vertex_markings=True, savename=None, saveextension=None, return_fig=False, legend=True):
83+
if not (isinstance(clustering_result, backend.Clustering_Result) and isinstance(curves, backend.Curves)):
84+
print("Check parameters!")
85+
return
86+
if len(clustering_result.assignment) < 1:
87+
print("compute_assignment was not called! calling now")
88+
clustering_result.compute_assignment(curves)
89+
from mpl_toolkits.mplot3d import Axes3D
90+
import matplotlib.pyplot as plt
91+
import matplotlib.colors as mcolors
92+
colors = list(mcolors.BASE_COLORS)
93+
if len(clustering_result) > len(colors):
94+
colors = list(mcolors.TABLEAU_COLORS)
95+
if len(clustering_result) > len(colors):
96+
colors = list(mcolors.mcolors.CSS4_COLORS)
97+
max_compl = 1
98+
max_dim = 1
99+
fig = plt.figure()
100+
ax = None
101+
for curve in curves:
102+
max_compl = max(max_compl, curve.complexity)
103+
max_dim = max(max_dim, curve.dimensions)
104+
if max_dim >= 3:
105+
ax = fig.gca(projection='3d')
45106
else:
107+
ax = fig.gca()
108+
for i, curve in enumerate(clustering_result):
109+
if curve.dimensions >= 3:
110+
p = ax.plot(curve.values[:, 0], curve.values[:, 1], curve.values[:, 2], linestyle='-', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity)
111+
if vertex_markings:
112+
ax.plot(curve.values[1:, 0], curve.values[1:, 1], curve.values[:, 2], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
113+
elif curve.dimensions == 2:
114+
p = ax.plot(curve.values[:, 0], curve.values[:, 1], linestyle='-', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity)
115+
if vertex_markings:
116+
ax.plot(curve.values[1:, 0], curve.values[1:, 1], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
117+
else:
118+
p = ax.plot([i * max_compl / len(curve) for i in range(len(curve))], curve.values, linestyle='-', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity)
119+
if vertex_markings:
120+
ax.plot([i * max_compl / len(curve) for i in range(1, len(curve))], curve.values[1:], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
121+
for i in range(len(clustering_result.assignment)):
122+
for j in range(clustering_result.assignment.count(i)):
123+
curve = curves[clustering_result.assignment.get(i,j)]
124+
if curve.dimensions >= 3:
125+
p = ax.plot(curve.values[:, 0], curve.values[:, 1], curve.values[:, 2], linestyle=':', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity)
126+
if vertex_markings:
127+
ax.plot(curve.values[1:, 0], curve.values[1:, 1], curve.values[:, 2], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
128+
elif curve.dimensions == 2:
129+
p = ax.plot(curve.values[:, 0], curve.values[:, 1], linestyle=':', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity)
130+
if vertex_markings:
131+
ax.plot(curve.values[1:, 0], curve.values[1:, 1], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
132+
else:
133+
p = ax.plot([i * max_compl / len(curve) for i in range(len(curve))], curve.values, linestyle=':', marker='o', label = curve.name, color = colors[i], markersize = 7, markevery = curve.complexity)
134+
if vertex_markings:
135+
ax.plot([i * max_compl / len(curve) for i in range(1, len(curve))], curve.values[1:], linestyle='', marker='x', label = None, color = p[0].get_color(), markersize = 7)
136+
if legend:
137+
ax.legend(title='Curve names:')
138+
ax.set_title('Fred Clustering')
139+
if savename is not None:
46140
plt.savefig("{}.{}".format(savename, saveextension), dpi=150)
47-
plt.close()
141+
plt.close()
142+
elif return_fig:
143+
return fig
144+
else:
145+
plt.show()
146+
plt.close()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def build_extension(self, ext):
7474

7575
setup(
7676
name='Fred-Frechet',
77-
version='1.9.13',
77+
version='1.9.20',
7878
author='Dennis Rohde',
7979
author_email='dennis.rohde@tu-dortmund.de',
8080
description='A fast, scalable and light-weight C++ Fréchet distance library, exposed to python and focused on (k,l)-clustering of polygonal curves.',

0 commit comments

Comments
 (0)