Skip to content

Commit 15f83d8

Browse files
authored
Enforce that queries have two dimensions (#167)
1 parent b8c68d3 commit 15f83d8

File tree

3 files changed

+4
-49
lines changed

3 files changed

+4
-49
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ def __init__(
126126
self.thread_executor = futures.ThreadPoolExecutor()
127127

128128
def query(self, queries: np.ndarray, k, **kwargs):
129-
if queries.ndim != 1 and queries.ndim != 2:
130-
raise TypeError(f"Expected queries to have either 1 or 2 dimensions (i.e. [...] or [[...], [...]]), but it had {queries.ndim} dimensions")
129+
if queries.ndim != 2:
130+
raise TypeError(f"Expected queries to have 2 dimensions (i.e. [[...], etc.]), but it had {queries.ndim} dimensions")
131131

132132
query_dimensions = queries.shape[0] if queries.ndim == 1 else queries.shape[1]
133133
if query_dimensions != self.get_dimensions():

apis/python/test/test_index.py

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,14 @@ def test_index_with_incorrect_dimensions(tmp_path):
102102
# Wrong number of dimensions will raise a TypeError.
103103
with pytest.raises(TypeError):
104104
index.query(np.array(1, dtype=np.float32), k=3)
105+
with pytest.raises(TypeError):
106+
index.query(np.array([1, 1, 1], dtype=np.float32), k=3)
105107
with pytest.raises(TypeError):
106108
index.query(np.array([[[1, 1, 1]]], dtype=np.float32), k=3)
107109
with pytest.raises(TypeError):
108110
index.query(np.array([[[[1, 1, 1]]]], dtype=np.float32), k=3)
109111

110112
# Okay otherwise.
111-
index.query(np.array([1, 1, 1], dtype=np.float32), k=3)
112113
index.query(np.array([[1, 1, 1]], dtype=np.float32), k=3)
113114

114115
def test_index_with_incorrect_num_of_query_columns_simple(tmp_path):
@@ -156,37 +157,3 @@ def test_index_with_incorrect_num_of_query_columns_complex(tmp_path):
156157
else:
157158
with pytest.raises(TypeError):
158159
index.query(query, k=1)
159-
160-
# TODO(paris): This will throw with the following error. Fix and re-enable, then remove
161-
# test_index_with_incorrect_num_of_query_columns_in_single_vector_query:
162-
# def array_to_matrix(array: np.ndarray):
163-
# if array.dtype == np.float32:
164-
# > return pyarray_copyto_matrix_f32(array)
165-
# E RuntimeError: Number of dimensions must be two
166-
# Here we test with a query which is just a vector, i.e. [1, 2, 3].
167-
# query = query[0]
168-
# if num_columns_for_query == num_columns:
169-
# index.query(query, k=1)
170-
# else:
171-
# with pytest.raises(TypeError):
172-
# index.query(query, k=1)
173-
174-
def test_index_with_incorrect_num_of_query_columns_in_single_vector_query(tmp_path):
175-
# Tests that we raise a TypeError if the number of columns in the query is not the same as the
176-
# number of columns in the indexed data, specifically for a single vector query.
177-
# i.e. queries = [1, 2, 3] instead of queries = [[1, 2, 3], [4, 5, 6]].
178-
indexes = [flat_index, ivf_flat_index]
179-
for index_type in indexes:
180-
uri = os.path.join(tmp_path, f"array_{index_type.__name__}")
181-
index = index_type.create(uri=uri, dimensions=3, vector_type=np.dtype(np.uint8))
182-
183-
# Wrong number of columns will raise a TypeError.
184-
with pytest.raises(TypeError):
185-
index.query(np.array([1], dtype=np.float32), k=3)
186-
with pytest.raises(TypeError):
187-
index.query(np.array([1, 1], dtype=np.float32), k=3)
188-
with pytest.raises(TypeError):
189-
index.query(np.array([1, 1, 1, 1], dtype=np.float32), k=3)
190-
191-
# Okay otherwise.
192-
index.query(np.array([1, 1, 1], dtype=np.float32), k=3)

apis/python/test/test_ingestion.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,6 @@ def test_ivf_flat_ingestion_fvec(tmp_path):
197197
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
198198
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
199199

200-
# Test single query vector handling
201-
_, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
202-
assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY
203-
204200
index_ram = IVFFlatIndex(uri=index_uri)
205201
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
206202
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
@@ -242,10 +238,6 @@ def test_ivf_flat_ingestion_numpy(tmp_path):
242238
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
243239
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
244240

245-
# Test single query vector handling
246-
_, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
247-
assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY
248-
249241
index_ram = IVFFlatIndex(uri=index_uri)
250242
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
251243
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
@@ -286,10 +278,6 @@ def test_ivf_flat_ingestion_multiple_workers(tmp_path):
286278
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
287279
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
288280

289-
# Test single query vector handling
290-
_, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
291-
assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY
292-
293281
index_ram = IVFFlatIndex(uri=index_uri)
294282
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
295283
assert accuracy(result, gt_i) > MINIMUM_ACCURACY

0 commit comments

Comments
 (0)