Skip to content

Commit 4cd316f

Browse files
authored
Validate the queries vector shape (#164)
1 parent 386304e commit 4cd316f

File tree

6 files changed

+182
-1
lines changed

6 files changed

+182
-1
lines changed

apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
+ self.index_version
4141
].uri
4242
schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
43+
self.dimensions = schema.shape[0]
4344
if self.base_size == -1:
4445
self.size = schema.domain.dim(1).domain[1] + 1
4546
else:
@@ -74,6 +75,9 @@ def __init__(
7475
self._ids = read_vector_u64(
7576
self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp
7677
)
78+
79+
def get_dimensions(self):
80+
return self.dimensions
7781

7882
def query_internal(
7983
self,

apis/python/src/tiledb/vector_search/index.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ def __init__(
126126
self.thread_executor = futures.ThreadPoolExecutor()
127127

128128
def query(self, queries: np.ndarray, k, **kwargs):
129+
if queries.ndim != 1 and queries.ndim != 2:
130+
raise TypeError(f"Expected queries to have either 1 or 2 dimensions (i.e. [...] or [[...], [...]]), but it had {queries.ndim} dimensions")
131+
132+
query_dimensions = queries.shape[0] if queries.ndim == 1 else queries.shape[1]
133+
if query_dimensions != self.get_dimensions():
134+
raise TypeError(f"A query in queries has {query_dimensions} dimensions, but the indexed data had {self.dimensions} dimensions")
135+
129136
with tiledb.scope_ctx(ctx_or_config=self.config):
130137
if not tiledb.array_exists(self.updates_array_uri):
131138
if self.query_base_array:
@@ -253,6 +260,9 @@ def read_additions(
253260
else:
254261
return None, None, updated_ids
255262

263+
def get_dimensions(self):
264+
raise NotImplementedError
265+
256266
def query_internal(self, queries: np.ndarray, k, **kwargs):
257267
raise NotImplementedError
258268

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,11 @@ def __init__(
6363
].uri
6464
self.memory_budget = memory_budget
6565

66+
schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
67+
self.dimensions = schema.shape[0]
68+
6669
self.dtype = self.group.meta.get("dtype", None)
6770
if self.dtype is None:
68-
schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
6971
self.dtype = np.dtype(schema.attr("values").dtype)
7072
else:
7173
self.dtype = np.dtype(self.dtype)
@@ -120,6 +122,9 @@ def __init__(
120122
self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp
121123
)
122124

125+
def get_dimensions(self):
126+
return self.dimensions
127+
123128
def query_internal(
124129
self,
125130
queries: np.ndarray,

apis/python/test/common.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,47 @@ def groundtruth_read(dataset_dir, nqueries=None):
7070
else:
7171
return I, D
7272

73+
def create_random_dataset_f32_only_data(nb, d, centers, path):
74+
"""
75+
Creates a random float32 dataset containing only the data vectors (no queries) and then writes it to disk.
76+
77+
Parameters
78+
----------
79+
nb: int
80+
Number of points in the dataset
81+
d: int
82+
Dimension of the dataset
83+
centers: int
84+
Number of centers
85+
path: str
86+
Path to write the dataset to
87+
"""
88+
from sklearn.datasets import make_blobs
89+
90+
os.mkdir(path)
91+
X, _ = make_blobs(n_samples=nb, n_features=d, centers=centers, random_state=1)
92+
93+
with open(os.path.join(path, "data.f32bin"), "wb") as f:
94+
np.array([nb, d], dtype="uint32").tofile(f)
95+
X.astype("float32").tofile(f)
7396

7497
def create_random_dataset_f32(nb, d, nq, k, path):
98+
"""
99+
Creates a random float32 dataset containing both a dataset and queries against it, and then writes those to disk.
100+
101+
Parameters
102+
----------
103+
nb: int
104+
Number of points in the dataset
105+
d: int
106+
Dimension of the dataset
107+
nq: int
108+
Number of queries
109+
k: int
110+
Number of nearest neighbors to return
111+
path: str
112+
Path to write the dataset to
113+
"""
75114
import sklearn.model_selection
76115
from sklearn.datasets import make_blobs
77116
from sklearn.neighbors import NearestNeighbors
@@ -104,6 +143,22 @@ def create_random_dataset_f32(nb, d, nq, k, path):
104143

105144

106145
def create_random_dataset_u8(nb, d, nq, k, path):
146+
"""
147+
Creates a random uint8 dataset containing both a dataset and queries against it, and then writes those to disk.
148+
149+
Parameters
150+
----------
151+
nb: int
152+
Number of points in the dataset
153+
d: int
154+
Dimension of the dataset
155+
nq: int
156+
Number of queries
157+
k: int
158+
Number of nearest neighbors to return
159+
path: str
160+
Path to write the dataset to
161+
"""
107162
import sklearn.model_selection
108163
from sklearn.datasets import make_blobs
109164
from sklearn.neighbors import NearestNeighbors

apis/python/test/test_index.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import numpy as np
22
from common import *
3+
import pytest
34

45
import tiledb.vector_search.index as ind
56
from tiledb.vector_search import flat_index, ivf_flat_index
67
from tiledb.vector_search.index import Index
8+
from tiledb.vector_search.ingestion import ingest
9+
from tiledb.vector_search.utils import load_fvecs
710

811
def query_and_check(index, queries, k, expected, **kwargs):
912
for _ in range(3):
@@ -89,3 +92,101 @@ def test_ivf_flat_index(tmp_path):
8992

9093
index = index.consolidate_updates()
9194
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {0, 2, 4}, nprobe=partitions)
95+
96+
def test_index_with_incorrect_dimensions(tmp_path):
97+
indexes = [flat_index, ivf_flat_index]
98+
for index_type in indexes:
99+
uri = os.path.join(tmp_path, f"array_{index_type.__name__}")
100+
index = index_type.create(uri=uri, dimensions=3, vector_type=np.dtype(np.uint8))
101+
102+
# Wrong number of dimensions will raise a TypeError.
103+
with pytest.raises(TypeError):
104+
index.query(np.array(1, dtype=np.float32), k=3)
105+
with pytest.raises(TypeError):
106+
index.query(np.array([[[1, 1, 1]]], dtype=np.float32), k=3)
107+
with pytest.raises(TypeError):
108+
index.query(np.array([[[[1, 1, 1]]]], dtype=np.float32), k=3)
109+
110+
# Okay otherwise.
111+
index.query(np.array([1, 1, 1], dtype=np.float32), k=3)
112+
index.query(np.array([[1, 1, 1]], dtype=np.float32), k=3)
113+
114+
def test_index_with_incorrect_num_of_query_columns_simple(tmp_path):
115+
siftsmall_uri = "test/data/siftsmall/siftsmall_base.fvecs"
116+
queries_uri = "test/data/siftsmall/siftsmall_query.fvecs"
117+
indexes = ["FLAT", "IVF_FLAT"]
118+
for index_type in indexes:
119+
index_uri = os.path.join(tmp_path, f"sift10k_flat_{index_type}")
120+
index = ingest(
121+
index_type=index_type,
122+
index_uri=index_uri,
123+
source_uri=siftsmall_uri,
124+
source_type = "FVEC",
125+
)
126+
127+
# Wrong number of columns will raise a TypeError.
128+
query_shape = (1, 1)
129+
with pytest.raises(TypeError):
130+
index.query(np.random.rand(*query_shape).astype(np.float32), k=10)
131+
132+
# Okay otherwise.
133+
query_vectors = load_fvecs(queries_uri)
134+
index.query(query_vectors, k=10)
135+
136+
def test_index_with_incorrect_num_of_query_columns_complex(tmp_path):
137+
# Tests that we raise a TypeError if the number of columns in the query is not the same as the
138+
# number of columns in the indexed data.
139+
size=1000
140+
indexes = ["FLAT", "IVF_FLAT"]
141+
num_columns_in_vector = [1, 2, 3, 4, 5, 10]
142+
for index_type in indexes:
143+
for num_columns in num_columns_in_vector:
144+
index_uri = os.path.join(tmp_path, f"array_{index_type}_{num_columns}")
145+
dataset_dir = os.path.join(tmp_path, f"dataset_{index_type}_{num_columns}")
146+
create_random_dataset_f32_only_data(nb=size, d=num_columns, centers=1, path=dataset_dir)
147+
index = ingest(index_type=index_type, index_uri=index_uri, source_uri=os.path.join(dataset_dir, "data.f32bin"))
148+
149+
# We have created a dataset with num_columns in each vector. Let's try creating queries
150+
# with different numbers of columns and confirming incorrect ones will throw.
151+
for num_columns_for_query in range(1, num_columns + 2):
152+
query_shape = (1, num_columns_for_query)
153+
query = np.random.rand(*query_shape).astype(np.float32)
154+
if query.shape[1] == num_columns:
155+
index.query(query, k=1)
156+
else:
157+
with pytest.raises(TypeError):
158+
index.query(query, k=1)
159+
160+
# TODO(paris): This will throw with the following error. Fix and re-enable, then remove
161+
# test_index_with_incorrect_num_of_query_columns_in_single_vector_query:
162+
# def array_to_matrix(array: np.ndarray):
163+
# if array.dtype == np.float32:
164+
# > return pyarray_copyto_matrix_f32(array)
165+
# E RuntimeError: Number of dimensions must be two
166+
# Here we test with a query which is just a vector, i.e. [1, 2, 3].
167+
# query = query[0]
168+
# if num_columns_for_query == num_columns:
169+
# index.query(query, k=1)
170+
# else:
171+
# with pytest.raises(TypeError):
172+
# index.query(query, k=1)
173+
174+
def test_index_with_incorrect_num_of_query_columns_in_single_vector_query(tmp_path):
175+
# Tests that we raise a TypeError if the number of columns in the query is not the same as the
176+
# number of columns in the indexed data, specifically for a single vector query.
177+
# i.e. queries = [1, 2, 3] instead of queries = [[1, 2, 3], [4, 5, 6]].
178+
indexes = [flat_index, ivf_flat_index]
179+
for index_type in indexes:
180+
uri = os.path.join(tmp_path, f"array_{index_type.__name__}")
181+
index = index_type.create(uri=uri, dimensions=3, vector_type=np.dtype(np.uint8))
182+
183+
# Wrong number of columns will raise a TypeError.
184+
with pytest.raises(TypeError):
185+
index.query(np.array([1], dtype=np.float32), k=3)
186+
with pytest.raises(TypeError):
187+
index.query(np.array([1, 1], dtype=np.float32), k=3)
188+
with pytest.raises(TypeError):
189+
index.query(np.array([1, 1, 1, 1], dtype=np.float32), k=3)
190+
191+
# Okay otherwise.
192+
index.query(np.array([1, 1, 1], dtype=np.float32), k=3)

src/include/detail/linalg/tdb_partitioned_matrix.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,12 @@ class tdbPartitionedMatrix : public Matrix<T, LayoutPolicy, I> {
293293
std::get<1>(col_part_view_) = std::get<0>(col_part_view_);
294294
for (size_t i = std::get<0>(col_part_view_); i < total_num_parts_; ++i) {
295295
auto next_part_size = indices_[parts_[i] + 1] - indices_[parts_[i]];
296+
297+
// Continue if this partition is empty
298+
if (next_part_size == 0) {
299+
continue;
300+
}
301+
296302
if ((std::get<1>(col_view_) + next_part_size) >
297303
std::get<0>(col_view_) + max_cols_) {
298304
break;

0 commit comments

Comments
 (0)