77
88import logging
99import os
10- import shutil
1110import tarfile
1211import time
1312import urllib .request
2928class RemoteURIType (Enum ):
3029 LOCAL = 1
3130 TILEDB = 2
31+ AWS = 3
3232
3333
3434## Settings
@@ -170,7 +170,7 @@ def _summary_string(self):
170170 summary_str += "\n "
171171 return summary_str
172172
173- def add_data_to_ingestion_time_vs_average_query_accuracy (self ):
173+ def add_data_to_ingestion_time_vs_average_query_accuracy (self , marker = "o" ):
174174 summary = self ._summarize_data ()
175175
176176 for tag , data in summary .items ():
@@ -183,9 +183,9 @@ def add_data_to_ingestion_time_vs_average_query_accuracy(self):
183183 (data ["ingestion" ]["times" ][i ], average_accuracy )
184184 )
185185 x , y = zip (* ingestion_times )
186- plt .scatter (y , x , marker = "o" , label = tag )
186+ plt .scatter (y , x , marker = marker , label = tag )
187187
188- def add_data_to_query_time_vs_accuracy (self ):
188+ def add_data_to_query_time_vs_accuracy (self , marker = "o" ):
189189 summary = self ._summarize_data ()
190190
191191 for tag , data in summary .items ():
@@ -195,7 +195,7 @@ def add_data_to_query_time_vs_accuracy(self):
195195 (data ["query" ]["times" ][i ], data ["query" ]["accuracies" ][i ])
196196 )
197197 x , y = zip (* query_times )
198- plt .plot (y , x , marker = "o" , label = tag )
198+ plt .plot (y , x , marker = marker , label = tag )
199199
200200 def save_charts (self ):
201201 # Plot ingestion.
@@ -239,13 +239,17 @@ def new_timer(self, name):
239239 return timer
240240
241241 def save_charts (self ):
242+ markers = ["o" , "^" , "D" , "*" , "P" , "s" , "2" ]
243+
242244 # Plot ingestion.
243245 plt .figure (figsize = (20 , 12 ))
244246 plt .xlabel ("Average Query Accuracy" )
245247 plt .ylabel ("Time (seconds)" )
246248 plt .title ("Ingestion Time vs Average Query Accuracy" )
247- for timer in self .timers :
248- timer .add_data_to_ingestion_time_vs_average_query_accuracy ()
249+ for idx , timer in self .timers :
250+ timer .add_data_to_ingestion_time_vs_average_query_accuracy (
251+ markers [idx % len (markers )]
252+ )
249253 plt .legend ()
250254 plt .savefig (os .path .join (RESULTS_DIR , "ingestion_time_vs_accuracy.png" ))
251255 plt .close ()
@@ -255,8 +259,8 @@ def save_charts(self):
255259 plt .xlabel ("Accuracy" )
256260 plt .ylabel ("Time (seconds)" )
257261 plt .title ("Query Time vs Accuracy" )
258- for timer in self .timers :
259- timer .add_data_to_query_time_vs_accuracy ()
262+ for idx , timer in self .timers :
263+ timer .add_data_to_query_time_vs_accuracy (markers [ idx % len ( markers )] )
260264 plt .legend ()
261265 plt .savefig (os .path .join (RESULTS_DIR , "query_time_vs_accuracy.png" ))
262266 plt .close ()
@@ -281,32 +285,44 @@ def download_and_extract(url, download_path, extract_path):
281285 logger .info ("Finished extracting files." )
282286
283287
288+ config = {}
289+
290+
284291def get_uri (tag ):
285292 index_name = f"index_{ tag .replace ('=' , '_' )} "
293+ index_uri = ""
286294 if REMOTE_URI_TYPE == RemoteURIType .LOCAL :
287295 index_uri = os .path .join (TEMP_DIR , index_name )
288- logger .info (f"Local URI { index_uri } " )
289- if os .path .exists (index_uri ):
290- shutil .rmtree (index_uri )
291- return index_uri
292296 elif REMOTE_URI_TYPE == RemoteURIType .TILEDB :
293297 from common import create_cloud_uri
294298 from common import setUpCloudToken
295299
296300 setUpCloudToken ()
297301 index_uri = create_cloud_uri (index_name , "local_benchmarks" )
298- logger .info (f"TileDB URI { index_uri } " )
299- Index .delete_index (uri = index_uri , config = tiledb .cloud .Config ())
300- return index_uri
302+
303+ config = tiledb .cloud .Config ()
304+ elif REMOTE_URI_TYPE == RemoteURIType .AWS :
305+ from common import create_cloud_uri
306+ from common import setUpCloudToken
307+
308+ setUpCloudToken ()
309+ index_uri = create_cloud_uri (index_name , "local_benchmarks" , True )
310+
311+ config = {
312+ "vfs.s3.aws_access_key_id" : os .environ ["AWS_ACCESS_KEY_ID" ],
313+ "vfs.s3.aws_secret_access_key" : os .environ ["AWS_SECRET_ACCESS_KEY" ],
314+ "vfs.s3.region" : os .environ ["AWS_REGION" ],
315+ }
301316 else :
302317 raise ValueError (f"Invalid REMOTE_URI_TYPE { REMOTE_URI_TYPE } " )
303318
319+ logger .info (f"index_uri: { index_uri } " )
320+ Index .delete_index (index_uri , config )
321+ return index_uri
304322
305- def cleanup_uri (index_uri ):
306- if REMOTE_URI_TYPE == RemoteURIType .TILEDB :
307- from common import delete_uri
308323
309- delete_uri (uri = index_uri , config = tiledb .cloud .Config ())
324+ def cleanup_uri (index_uri ):
325+ Index .delete_index (index_uri , config )
310326
311327
312328def benchmark_ivf_flat ():
@@ -328,9 +344,7 @@ def benchmark_ivf_flat():
328344 index_type = index_type ,
329345 index_uri = index_uri ,
330346 source_uri = SIFT_BASE_PATH ,
331- config = tiledb .cloud .Config ().dict ()
332- if REMOTE_URI_TYPE is not None
333- else None ,
347+ config = config ,
334348 partitions = partitions ,
335349 training_sampling_policy = TrainingSamplingPolicy .RANDOM ,
336350 )
@@ -370,9 +384,7 @@ def benchmark_vamana():
370384 index_type = index_type ,
371385 index_uri = index_uri ,
372386 source_uri = SIFT_BASE_PATH ,
373- config = tiledb .cloud .Config ().dict ()
374- if REMOTE_URI_TYPE is not None
375- else None ,
387+ config = config ,
376388 l_build = l_build ,
377389 r_max_degree = r_max_degree ,
378390 training_sampling_policy = TrainingSamplingPolicy .RANDOM ,
@@ -414,9 +426,7 @@ def benchmark_ivf_pq():
414426 index_type = index_type ,
415427 index_uri = index_uri ,
416428 source_uri = SIFT_BASE_PATH ,
417- config = tiledb .cloud .Config ().dict ()
418- if REMOTE_URI_TYPE is not None
419- else None ,
429+ config = config ,
420430 partitions = partitions ,
421431 training_sampling_policy = TrainingSamplingPolicy .RANDOM ,
422432 num_subspaces = num_subspaces ,
0 commit comments