Skip to content

Commit 701b9cc

Browse files
authored
Support AWS index URI in local-benchmarks.py (#516)
1 parent 8b4a53e commit 701b9cc

File tree

2 files changed

+46
-32
lines changed

2 files changed

+46
-32
lines changed

apis/python/test/common.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,13 +394,17 @@ def setUpCloudToken():
394394
tiledb.cloud.login(token=token)
395395

396396

397-
def create_cloud_uri(name, folder_name=None, aws_uri=False):
    """Compose the URI of a test object under the default storage path.

    Args:
        name: Last path component of the returned URI.
        folder_name: Path component placed between the storage path and
            ``name``. When falsy, ``random_name("vector_search")`` is used.
        aws_uri: When true, return ``<storage_path>/<folder_name>/<name>``
            without the ``tiledb://<namespace>/`` prefix.

    Returns:
        The composed URI string.
    """
    namespace, storage_path, _ = groups._default_ns_path_cred()
    # Replace every "//" with "/", then turn the first "/" back into "//",
    # so only the first slash run in the path stays doubled.
    storage_path = storage_path.replace("//", "/").replace("/", "//", 1)

    if not folder_name:
        folder_name = random_name("vector_search")

    object_path = f"{storage_path}/{folder_name}/{name}"
    if aws_uri:
        return object_path
    return f"tiledb://{namespace}/{object_path}"
404408

405409

406410
def delete_uri(uri, config):

apis/python/test/local-benchmarks.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import logging
99
import os
10-
import shutil
1110
import tarfile
1211
import time
1312
import urllib.request
@@ -29,6 +28,7 @@
2928
class RemoteURIType(Enum):
    """Kinds of index location that get_uri() can produce."""

    # Index directory joined onto TEMP_DIR.
    LOCAL = 1
    # tiledb:// URI built by create_cloud_uri().
    TILEDB = 2
    # Storage-path URI built by create_cloud_uri(..., aws_uri=True).
    AWS = 3
3232

3333

3434
## Settings
@@ -170,7 +170,7 @@ def _summary_string(self):
170170
summary_str += "\n"
171171
return summary_str
172172

173-
def add_data_to_ingestion_time_vs_average_query_accuracy(self):
173+
def add_data_to_ingestion_time_vs_average_query_accuracy(self, marker="o"):
174174
summary = self._summarize_data()
175175

176176
for tag, data in summary.items():
@@ -183,9 +183,9 @@ def add_data_to_ingestion_time_vs_average_query_accuracy(self):
183183
(data["ingestion"]["times"][i], average_accuracy)
184184
)
185185
x, y = zip(*ingestion_times)
186-
plt.scatter(y, x, marker="o", label=tag)
186+
plt.scatter(y, x, marker=marker, label=tag)
187187

188-
def add_data_to_query_time_vs_accuracy(self):
188+
def add_data_to_query_time_vs_accuracy(self, marker="o"):
189189
summary = self._summarize_data()
190190

191191
for tag, data in summary.items():
@@ -195,7 +195,7 @@ def add_data_to_query_time_vs_accuracy(self):
195195
(data["query"]["times"][i], data["query"]["accuracies"][i])
196196
)
197197
x, y = zip(*query_times)
198-
plt.plot(y, x, marker="o", label=tag)
198+
plt.plot(y, x, marker=marker, label=tag)
199199

200200
def save_charts(self):
201201
# Plot ingestion.
@@ -239,13 +239,17 @@ def new_timer(self, name):
239239
return timer
240240

241241
def save_charts(self):
242+
markers = ["o", "^", "D", "*", "P", "s", "2"]
243+
242244
# Plot ingestion.
243245
plt.figure(figsize=(20, 12))
244246
plt.xlabel("Average Query Accuracy")
245247
plt.ylabel("Time (seconds)")
246248
plt.title("Ingestion Time vs Average Query Accuracy")
247-
# enumerate() supplies the index used to cycle through `markers`; iterating
# self.timers directly would try to unpack each timer into (idx, timer).
for idx, timer in enumerate(self.timers):
    timer.add_data_to_ingestion_time_vs_average_query_accuracy(
        markers[idx % len(markers)]
    )
249253
plt.legend()
250254
plt.savefig(os.path.join(RESULTS_DIR, "ingestion_time_vs_accuracy.png"))
251255
plt.close()
@@ -255,8 +259,8 @@ def save_charts(self):
255259
plt.xlabel("Accuracy")
256260
plt.ylabel("Time (seconds)")
257261
plt.title("Query Time vs Accuracy")
258-
# enumerate() supplies the index used to cycle through `markers`; iterating
# self.timers directly would try to unpack each timer into (idx, timer).
for idx, timer in enumerate(self.timers):
    timer.add_data_to_query_time_vs_accuracy(markers[idx % len(markers)])
260264
plt.legend()
261265
plt.savefig(os.path.join(RESULTS_DIR, "query_time_vs_accuracy.png"))
262266
plt.close()
@@ -281,32 +285,44 @@ def download_and_extract(url, download_path, extract_path):
281285
logger.info("Finished extracting files.")
282286

283287

288+
# Storage configuration shared by get_uri(), cleanup_uri() and the
# benchmark ingest calls. get_uri() rebinds it for the remote URI types.
config = {}


def get_uri(tag):
    """Return a fresh index URI for ``tag`` and set the matching ``config``.

    The URI is built according to ``REMOTE_URI_TYPE``. Any index already at
    that URI is deleted before the URI is returned.

    Args:
        tag: Benchmark tag; every ``=`` is replaced by ``_`` in the index name.

    Returns:
        The index URI string.

    Raises:
        ValueError: If ``REMOTE_URI_TYPE`` is not a handled ``RemoteURIType``.
        KeyError: For the AWS type, if one of ``AWS_ACCESS_KEY_ID``,
            ``AWS_SECRET_ACCESS_KEY`` or ``AWS_REGION`` is not set.
    """
    # Without this declaration the assignments below create a function-local
    # `config`: the LOCAL path then hits UnboundLocalError at delete_index,
    # and the module-level `config` read elsewhere never leaves `{}`.
    global config

    index_name = f"index_{tag.replace('=', '_')}"
    if REMOTE_URI_TYPE == RemoteURIType.LOCAL:
        index_uri = os.path.join(TEMP_DIR, index_name)
    elif REMOTE_URI_TYPE == RemoteURIType.TILEDB:
        from common import create_cloud_uri
        from common import setUpCloudToken

        setUpCloudToken()
        index_uri = create_cloud_uri(index_name, "local_benchmarks")

        # Stored in dict form, like the AWS branch below, so every consumer
        # of `config` receives a plain mapping.
        config = tiledb.cloud.Config().dict()
    elif REMOTE_URI_TYPE == RemoteURIType.AWS:
        from common import create_cloud_uri
        from common import setUpCloudToken

        setUpCloudToken()
        index_uri = create_cloud_uri(index_name, "local_benchmarks", True)

        config = {
            "vfs.s3.aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID"],
            "vfs.s3.aws_secret_access_key": os.environ["AWS_SECRET_ACCESS_KEY"],
            "vfs.s3.region": os.environ["AWS_REGION"],
        }
    else:
        raise ValueError(f"Invalid REMOTE_URI_TYPE {REMOTE_URI_TYPE}")

    logger.info(f"index_uri: {index_uri}")
    # Remove any index left at this URI before handing it out.
    Index.delete_index(index_uri, config)
    return index_uri
304322

305-
def cleanup_uri(index_uri):
    """Delete the index at ``index_uri`` using the module-level ``config``.

    Args:
        index_uri: URI of the index to delete.
    """
    Index.delete_index(index_uri, config)
310326

311327

312328
def benchmark_ivf_flat():
@@ -328,9 +344,7 @@ def benchmark_ivf_flat():
328344
index_type=index_type,
329345
index_uri=index_uri,
330346
source_uri=SIFT_BASE_PATH,
331-
config=tiledb.cloud.Config().dict()
332-
if REMOTE_URI_TYPE is not None
333-
else None,
347+
config=config,
334348
partitions=partitions,
335349
training_sampling_policy=TrainingSamplingPolicy.RANDOM,
336350
)
@@ -370,9 +384,7 @@ def benchmark_vamana():
370384
index_type=index_type,
371385
index_uri=index_uri,
372386
source_uri=SIFT_BASE_PATH,
373-
config=tiledb.cloud.Config().dict()
374-
if REMOTE_URI_TYPE is not None
375-
else None,
387+
config=config,
376388
l_build=l_build,
377389
r_max_degree=r_max_degree,
378390
training_sampling_policy=TrainingSamplingPolicy.RANDOM,
@@ -414,9 +426,7 @@ def benchmark_ivf_pq():
414426
index_type=index_type,
415427
index_uri=index_uri,
416428
source_uri=SIFT_BASE_PATH,
417-
config=tiledb.cloud.Config().dict()
418-
if REMOTE_URI_TYPE is not None
419-
else None,
429+
config=config,
420430
partitions=partitions,
421431
training_sampling_policy=TrainingSamplingPolicy.RANDOM,
422432
num_subspaces=num_subspaces,

0 commit comments

Comments
 (0)