Skip to content

Commit 9c029e9

Browse files
authored
Small refactor of unit tests to consolidate language and add a query and check helper to test_ingestion.py (#176)
1 parent f5734f9 commit 9c029e9

File tree

3 files changed

+131
-138
lines changed

3 files changed

+131
-138
lines changed

apis/python/test/test_cloud.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_cloud_flat(self):
4040
k = 100
4141
nqueries = 100
4242

43-
query_vectors = load_fvecs(queries_uri)
43+
queries = load_fvecs(queries_uri)
4444
gt_i, gt_d = get_groundtruth_ivec(gt_uri, k=k, nqueries=nqueries)
4545

4646
index = vs.ingest(
@@ -50,7 +50,7 @@ def test_cloud_flat(self):
5050
config=tiledb.cloud.Config().dict(),
5151
mode=Mode.BATCH,
5252
)
53-
_, result_i = index.query(query_vectors, k=k)
53+
_, result_i = index.query(queries, k=k)
5454
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
5555

5656
def test_cloud_ivf_flat(self):
@@ -63,7 +63,7 @@ def test_cloud_ivf_flat(self):
6363
nqueries = 100
6464
nprobe = 20
6565

66-
query_vectors = load_fvecs(queries_uri)
66+
queries = load_fvecs(queries_uri)
6767
gt_i, gt_d = get_groundtruth_ivec(gt_uri, k=k, nqueries=nqueries)
6868

6969
index = vs.ingest(
@@ -79,16 +79,16 @@ def test_cloud_ivf_flat(self):
7979
# mode=Mode.BATCH,
8080
)
8181

82-
_, result_i = index.query(query_vectors, k=k, nprobe=nprobe)
82+
_, result_i = index.query(queries, k=k, nprobe=nprobe)
8383
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
8484

8585
_, result_i = index.query(
86-
query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, num_partitions=2, resource_class="standard"
86+
queries, k=k, nprobe=nprobe, mode=Mode.REALTIME, num_partitions=2, resource_class="standard"
8787
)
8888
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
8989

9090
_, result_i = index.query(
91-
query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL, num_partitions=2
91+
queries, k=k, nprobe=nprobe, mode=Mode.LOCAL, num_partitions=2
9292
)
9393
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
9494

@@ -97,20 +97,20 @@ def test_cloud_ivf_flat(self):
9797

9898
# Cannot pass resource_class or resources to LOCAL mode or to no mode.
9999
with self.assertRaises(TypeError):
100-
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL, resource_class="large")
100+
index.query(queries, k=k, nprobe=nprobe, mode=Mode.LOCAL, resource_class="large")
101101
with self.assertRaises(TypeError):
102-
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL, resources=resources)
102+
index.query(queries, k=k, nprobe=nprobe, mode=Mode.LOCAL, resources=resources)
103103
with self.assertRaises(TypeError):
104-
index.query(query_vectors, k=k, nprobe=nprobe, resource_class="large")
104+
index.query(queries, k=k, nprobe=nprobe, resource_class="large")
105105
with self.assertRaises(TypeError):
106-
index.query(query_vectors, k=k, nprobe=nprobe, resources=resources)
106+
index.query(queries, k=k, nprobe=nprobe, resources=resources)
107107

108108
# Cannot pass resources to REALTIME.
109109
with self.assertRaises(TypeError):
110-
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, resources=resources)
110+
index.query(queries, k=k, nprobe=nprobe, mode=Mode.REALTIME, resources=resources)
111111

112112
# Cannot pass both resource_class and resources.
113113
with self.assertRaises(TypeError):
114-
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, resource_class="large", resources=resources)
114+
index.query(queries, k=k, nprobe=nprobe, mode=Mode.REALTIME, resource_class="large", resources=resources)
115115
with self.assertRaises(TypeError):
116-
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.BATCH, resource_class="large", resources=resources)
116+
index.query(queries, k=k, nprobe=nprobe, mode=Mode.BATCH, resource_class="large", resources=resources)

apis/python/test/test_index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def test_index_with_incorrect_num_of_query_columns_simple(tmp_path):
131131
index.query(np.random.rand(*query_shape).astype(np.float32), k=10)
132132

133133
# Okay otherwise.
134-
query_vectors = load_fvecs(queries_uri)
135-
index.query(query_vectors, k=10)
134+
queries = load_fvecs(queries_uri)
135+
index.query(queries, k=10)
136136

137137
def test_index_with_incorrect_num_of_query_columns_complex(tmp_path):
138138
# Tests that we raise a TypeError if the number of columns in the query is not the same as the

0 commit comments

Comments
 (0)