Skip to content

Commit 7550865

Browse files
authored
Add new plot to local-benchmarks.py showing results from all indexes (#512)
1 parent 6bb588e commit 7550865

File tree

1 file changed

+67
-23
lines changed

1 file changed

+67
-23
lines changed

apis/python/test/local-benchmarks.py

Lines changed: 67 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,7 @@ def accuracy(self, tag, acc):
113113
self.tagToAccuracies[tag].append(acc)
114114
return acc
115115

116-
def summarize_data(self):
116+
def _summarize_data(self):
117117
summary = {}
118118
for key, intervals in self.keyToTimes.items():
119119
tag, mode = key.rsplit("_", 1)
@@ -144,8 +144,8 @@ def summarize_data(self):
144144

145145
return summary
146146

147-
def summary_string(self):
148-
summary = self.summarize_data()
147+
def _summary_string(self):
148+
summary = self._summarize_data()
149149
summary_str = f"Timer: {self.name}\n"
150150
for tag, data in summary.items():
151151
summary_str += f"{tag}\n"
@@ -160,14 +160,9 @@ def summary_string(self):
160160
summary_str += "\n"
161161
return summary_str
162162

163-
def save_charts(self):
164-
summary = self.summarize_data()
163+
def add_data_to_ingestion_time_vs_average_query_accuracy(self):
164+
summary = self._summarize_data()
165165

166-
# Plot ingestion.
167-
plt.figure(figsize=(20, 12))
168-
plt.xlabel("Average Query Accuracy")
169-
plt.ylabel("Time (seconds)")
170-
plt.title(f"{self.name}: Ingestion Time vs Average Query Accuracy")
171166
for tag, data in summary.items():
172167
ingestion_times = []
173168
average_accuracy = sum(data["query"]["accuracies"]) / len(
@@ -180,6 +175,25 @@ def save_charts(self):
180175
x, y = zip(*ingestion_times)
181176
plt.scatter(y, x, marker="o", label=tag)
182177

178+
def add_data_to_query_time_vs_accuracy(self):
    """Add one query-time-vs-accuracy line per tag to the active plot.

    Reads the summary from ``self._summarize_data()`` and, for each tag,
    calls ``plt.plot`` with accuracies on the x axis and query times on the
    y axis, labelled with the tag. The caller is responsible for creating
    the figure beforehand and for the legend/save/close afterwards.
    (``plt`` is presumably ``matplotlib.pyplot``, imported elsewhere in the
    file.)

    Tags with no recorded queries are skipped instead of raising.
    """
    summary = self._summarize_data()

    for tag, data in summary.items():
        query = data["query"]
        count = query["count"]
        if count <= 0:
            # With zero samples the previous `x, y = zip(*pairs)` unpacking
            # raised ValueError; there is simply nothing to draw for this tag.
            continue
        # Only the first `count` samples belong to the series.
        times = query["times"][:count]
        accuracies = query["accuracies"][:count]
        plt.plot(accuracies, times, marker="o", label=tag)
189+
190+
def save_charts(self):
191+
# Plot ingestion.
192+
plt.figure(figsize=(20, 12))
193+
plt.xlabel("Average Query Accuracy")
194+
plt.ylabel("Time (seconds)")
195+
plt.title(f"{self.name}: Ingestion Time vs Average Query Accuracy")
196+
self.add_data_to_ingestion_time_vs_average_query_accuracy()
183197
plt.legend()
184198
plt.savefig(
185199
os.path.join(RESULTS_DIR, f"{self.name}_ingestion_time_vs_accuracy.png")
@@ -191,28 +205,56 @@ def save_charts(self):
191205
plt.xlabel("Accuracy")
192206
plt.ylabel("Time (seconds)")
193207
plt.title(f"{self.name}: Query Time vs Accuracy")
194-
for tag, data in summary.items():
195-
query_times = []
196-
for i in range(data["query"]["count"]):
197-
query_times.append(
198-
(data["query"]["times"][i], data["query"]["accuracies"][i])
199-
)
200-
x, y = zip(*query_times)
201-
plt.plot(y, x, marker="o", label=tag)
202-
208+
self.add_data_to_query_time_vs_accuracy()
203209
plt.legend()
204210
plt.savefig(
205211
os.path.join(RESULTS_DIR, f"{self.name}_query_time_vs_accuracy.png")
206212
)
207213
plt.close()
208214

209215
def save_and_print_results(self):
210-
summary_string = self.summary_string()
216+
summary_string = self._summary_string()
211217
logger.info(summary_string)
212218

213219
self.save_charts()
214220

215221

222+
class TimerManager:
    """Keeps track of every Timer created through it and draws combined charts.

    Each registered timer contributes its own series to two shared figures,
    so results from all timers appear side by side.
    """

    def __init__(self):
        # Timers in creation order; iterated when drawing the combined charts.
        self.timers = []

    def new_timer(self, name):
        """Create a ``Timer`` called ``name``, register it, and return it."""
        timer = Timer(name)
        self.timers.append(timer)
        return timer

    def _save_chart(self, title, xlabel, filename, add_data):
        """Draw one combined figure and save it as ``RESULTS_DIR/filename``.

        ``add_data`` is called once per registered timer and is expected to
        plot that timer's series onto the currently active figure.
        """
        plt.figure(figsize=(20, 12))
        try:
            plt.xlabel(xlabel)
            plt.ylabel("Time (seconds)")
            plt.title(title)
            for timer in self.timers:
                add_data(timer)
            plt.legend()
            plt.savefig(os.path.join(RESULTS_DIR, filename))
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close()

    def save_charts(self):
        """Save the combined ingestion chart and the combined query chart."""
        # Plot ingestion.
        self._save_chart(
            "Ingestion Time vs Average Query Accuracy",
            "Average Query Accuracy",
            "ingestion_time_vs_accuracy.png",
            lambda timer: timer.add_data_to_ingestion_time_vs_average_query_accuracy(),
        )

        # Plot query.
        self._save_chart(
            "Query Time vs Accuracy",
            "Accuracy",
            "query_time_vs_accuracy.png",
            lambda timer: timer.add_data_to_query_time_vs_accuracy(),
        )
253+
254+
255+
timer_manager = TimerManager()
256+
257+
216258
def download_and_extract(url, download_path, extract_path):
217259
if os.path.exists(download_path):
218260
logger.info(
@@ -231,7 +273,7 @@ def download_and_extract(url, download_path, extract_path):
231273

232274
def benchmark_ivf_flat():
233275
index_type = "IVF_FLAT"
234-
timer = Timer(name=index_type)
276+
timer = timer_manager.new_timer(index_type)
235277

236278
k = 100
237279
queries = load_fvecs(SIFT_QUERIES_PATH)
@@ -269,7 +311,7 @@ def benchmark_ivf_flat():
269311

270312
def benchmark_vamana():
271313
index_type = "VAMANA"
272-
timer = Timer(name=index_type)
314+
timer = timer_manager.new_timer(index_type)
273315

274316
k = 100
275317
queries = load_fvecs(SIFT_QUERIES_PATH)
@@ -309,7 +351,7 @@ def benchmark_vamana():
309351

310352
def benchmark_ivf_pq():
311353
index_type = "IVF_PQ"
312-
timer = Timer(name=index_type)
354+
timer = timer_manager.new_timer(index_type)
313355

314356
k = 100
315357
queries = load_fvecs(SIFT_QUERIES_PATH)
@@ -355,5 +397,7 @@ def main():
355397
benchmark_vamana()
356398
benchmark_ivf_pq()
357399

400+
timer_manager.save_charts()
401+
358402

359403
main()

0 commit comments

Comments
 (0)