Commit 5aee8a2

Add support for int8 type indexes (#342)
This adds support for `int8` type vectors. Scalar quantization to `int8` is the most popular way to quantize signed floating-point embeddings: https://huggingface.co/blog/embedding-quantization
1 parent 14d9b80 commit 5aee8a2
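
For orientation, here is a minimal sketch of the workflow this commit enables, mirroring the new `test_ingestion_numpy_i8` test below: quantize `float32` embeddings to `int8`, ingest them, then query with `float32` queries. The quantizer copies the helper added to `apis/python/test/common.py`; the `ingest` import path, the index URI, and the random data are assumptions for illustration.

```python
import numpy as np

# Assumed import path for the ingestion entry point exercised by the tests below.
from tiledb.vector_search.ingestion import ingest


def quantize_embeddings_int8(embeddings: np.ndarray) -> np.ndarray:
    # Same per-dimension scaling as the test helper added in this commit:
    # map each dimension's [min, max] range onto roughly [-128, 127].
    ranges = np.vstack((np.min(embeddings, axis=0), np.max(embeddings, axis=0)))
    starts = ranges[0, :]
    steps = (ranges[1, :] - ranges[0, :]) / 255
    return ((embeddings - starts) / steps - 128).astype(np.int8)


embeddings = np.random.rand(10_000, 128).astype(np.float32)  # placeholder data
input_vectors = quantize_embeddings_int8(embeddings)          # dtype: int8

index = ingest(
    index_type="IVF_FLAT",
    index_uri="/tmp/int8_index",  # hypothetical URI
    input_vectors=input_vectors,
    partitions=100,
)

# Queries are quantized the same way but passed as float32, as in the new test.
queries = quantize_embeddings_int8(embeddings[:10]).astype(np.float32)
distances, ids = index.query(queries, k=10, nprobe=20)
```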

5 files changed: +133 additions, -6 deletions

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 32 additions & 0 deletions
```diff
@@ -25,16 +25,19 @@ bool enable_stats = false;
 std::vector<json> core_stats;
 
 PYBIND11_MAKE_OPAQUE(std::vector<uint8_t>);
+PYBIND11_MAKE_OPAQUE(std::vector<int8_t>);
 PYBIND11_MAKE_OPAQUE(std::vector<uint32_t>);
 PYBIND11_MAKE_OPAQUE(std::vector<uint64_t>);
 PYBIND11_MAKE_OPAQUE(std::vector<float>);
 PYBIND11_MAKE_OPAQUE(std::vector<double>);
 PYBIND11_MAKE_OPAQUE(std::list<std::vector<uint8_t>>);
+PYBIND11_MAKE_OPAQUE(std::list<std::vector<int8_t>>);
 PYBIND11_MAKE_OPAQUE(std::list<std::vector<uint32_t>>);
 PYBIND11_MAKE_OPAQUE(std::list<std::vector<uint64_t>>);
 PYBIND11_MAKE_OPAQUE(std::list<std::vector<float>>);
 PYBIND11_MAKE_OPAQUE(std::list<std::vector<double>>);
 PYBIND11_MAKE_OPAQUE(std::vector<std::list<uint8_t>>);
+PYBIND11_MAKE_OPAQUE(std::vector<std::list<int8_t>>);
 PYBIND11_MAKE_OPAQUE(std::vector<std::list<uint32_t>>);
 PYBIND11_MAKE_OPAQUE(std::vector<std::list<uint64_t>>);
 PYBIND11_MAKE_OPAQUE(std::vector<std::list<float>>);
@@ -627,6 +630,7 @@ PYBIND11_MODULE(_tiledbvspy, m) {
   declareStdVector<float>(m, "f32");
   declareStdVector<double>(m, "f64");
   declareStdVector<uint8_t>(m, "u8");
+  declareStdVector<int8_t>(m, "i8");
   declareStdVector<uint32_t>(m, "u32");
   declareStdVector<uint64_t>(m, "u64");
   if constexpr (!std::is_same_v<uint64_t, size_t>) {
@@ -664,6 +668,7 @@ PYBIND11_MODULE(_tiledbvspy, m) {
   /* === Matrix === */
 
   declareColMajorMatrix<uint8_t>(m, "_u8");
+  declareColMajorMatrix<int8_t>(m, "_i8");
   declareColMajorMatrix<float>(m, "_f32");
   declareColMajorMatrix<double>(m, "_f64");
   declareColMajorMatrix<int32_t>(m, "_i32");
@@ -677,6 +682,8 @@ PYBIND11_MODULE(_tiledbvspy, m) {
 
   declareColMajorMatrixSubclass<tdbColMajorMatrix<uint8_t>>(
       m, "tdbColMajorMatrix", "_u8");
+  declareColMajorMatrixSubclass<tdbColMajorMatrix<int8_t>>(
+      m, "tdbColMajorMatrix", "_i8");
   declareColMajorMatrixSubclass<tdbColMajorMatrix<uint64_t>>(
       m, "tdbColMajorMatrix", "_u64");
   declareColMajorMatrixSubclass<tdbColMajorMatrix<float>>(
@@ -688,6 +695,7 @@ PYBIND11_MODULE(_tiledbvspy, m) {
 
   // Converters from pyarray to matrix
   declare_pyarray_to_matrix<uint8_t>(m, "_u8");
+  declare_pyarray_to_matrix<int8_t>(m, "_i8");
   declare_pyarray_to_matrix<uint64_t>(m, "_u64");
   declare_pyarray_to_matrix<float>(m, "_f32");
   declare_pyarray_to_matrix<double>(m, "_f64");
@@ -716,6 +724,17 @@ PYBIND11_MODULE(_tiledbvspy, m) {
         return r;
       });
 
+  m.def(
+      "query_vq_i8",
+      [](tdbColMajorMatrix<int8_t>& data,
+         ColMajorMatrix<float>& query_vectors,
+         int k,
+         size_t nthreads)
+          -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<uint64_t>> {
+        auto r = detail::flat::vq_query_heap(data, query_vectors, k, nthreads);
+        return r;
+      });
+
   m.def(
       "validate_top_k_u64",
       [](const ColMajorMatrix<uint64_t>& top_k,
@@ -724,33 +743,45 @@ PYBIND11_MODULE(_tiledbvspy, m) {
       });
 
   declare_vq_query_heap<uint8_t>(m, "u8");
+  declare_vq_query_heap<int8_t>(m, "i8");
   declare_vq_query_heap<float>(m, "f32");
   declare_vq_query_heap_pyarray<uint8_t>(m, "u8");
+  declare_vq_query_heap_pyarray<int8_t>(m, "i8");
   declare_vq_query_heap_pyarray<float>(m, "f32");
 
   declare_qv_query_heap_infinite_ram<uint8_t>(m, "u8");
+  declare_qv_query_heap_infinite_ram<int8_t>(m, "i8");
   declare_qv_query_heap_infinite_ram<float>(m, "f32");
   declare_qv_query_heap_finite_ram<uint8_t>(m, "u8");
+  declare_qv_query_heap_finite_ram<int8_t>(m, "i8");
   declare_qv_query_heap_finite_ram<float>(m, "f32");
   declare_nuv_query_heap_infinite_ram<uint8_t>(m, "u8");
+  declare_nuv_query_heap_infinite_ram<int8_t>(m, "i8");
   declare_nuv_query_heap_infinite_ram<float>(m, "f32");
   declare_nuv_query_heap_finite_ram<uint8_t>(m, "u8");
+  declare_nuv_query_heap_finite_ram<int8_t>(m, "i8");
   declare_nuv_query_heap_finite_ram<float>(m, "f32");
 
   declare_ivf_index<uint8_t>(m, "u8");
+  declare_ivf_index<int8_t>(m, "i8");
   declare_ivf_index<float>(m, "f32");
   declare_ivf_index_tdb<uint8_t>(m, "u8");
+  declare_ivf_index_tdb<int8_t>(m, "i8");
   declare_ivf_index_tdb<float>(m, "f32");
 
   declarePartitionIvfIndex<uint8_t>(m, "u8");
+  declarePartitionIvfIndex<int8_t>(m, "i8");
   declarePartitionIvfIndex<float>(m, "f32");
 
   declarePartitionedMatrix<uint8_t, uint64_t, uint64_t, uint64_t>(
       m, "tdbPartitionedMatrix", "u8");
+  declarePartitionedMatrix<int8_t, uint64_t, uint64_t, uint64_t>(
+      m, "tdbPartitionedMatrix", "i8");
   declarePartitionedMatrix<float, uint64_t, uint64_t, uint64_t>(
       m, "tdbPartitionedMatrix", "f32");
 
   declare_dist_qv<uint8_t>(m, "u8");
+  declare_dist_qv<int8_t>(m, "i8");
   declare_dist_qv<float>(m, "f32");
   declareFixedMinPairHeap(m);
 
@@ -770,6 +801,7 @@ PYBIND11_MODULE(_tiledbvspy, m) {
   m.def("stats_dump", []() { return json{core_stats}.dump(); });
 
   declare_debug_slice<uint8_t>(m, "_u8");
+  declare_debug_slice<int8_t>(m, "_i8");
   declare_debug_slice<float>(m, "_f32");
   declare_debug_slice<uint64_t>(m, "_u64");
 
```
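
The new `query_vq_i8` binding pairs an `int8` database matrix (`tdbColMajorMatrix<int8_t>`) with `float32` query vectors. Below is a sketch of reaching it through the Python wrappers updated in `module.py`; the import path, the array URI, and calling `load_as_matrix` with only a URI are assumptions about signatures not shown in this diff.

```python
import numpy as np

# Assumed import path; the wrappers live in apis/python/src/tiledb/vector_search/module.py.
from tiledb.vector_search.module import array_to_matrix, load_as_matrix, query_vq_nth

# An int8 vector array previously written to TileDB (URI is hypothetical).
db = load_as_matrix("/tmp/int8_vectors")  # now dispatches to tdbColMajorMatrix_i8

# Queries stay float32; matrices are column-major, so shape is (dims, nqueries).
queries = array_to_matrix(np.random.rand(128, 10).astype(np.float32))

# query_vq_nth() checks db.dtype and routes np.int8 to the new query_vq_i8 binding,
# which returns a (scores, ids) pair of matrices.
scores, ids = query_vq_nth(db, queries, 10, 8)  # k=10, nthreads=8
```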

apis/python/src/tiledb/vector_search/module.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -48,6 +48,8 @@ def load_as_matrix(
         m = tdbColMajorMatrix_i64(ctx, path, 0, None, 0, size, 0, timestamp)
     elif dtype == np.uint8:
         m = tdbColMajorMatrix_u8(ctx, path, 0, None, 0, size, 0, timestamp)
+    elif dtype == np.int8:
+        m = tdbColMajorMatrix_i8(ctx, path, 0, None, 0, size, 0, timestamp)
     # elif dtype == np.uint64:
     #     return tdbColMajorMatrix_u64(ctx, path, size, timestamp)
     else:
@@ -91,6 +93,8 @@ def debug_slice(m: "colMajorMatrix", name: str):
         return debug_slice_f32(m, name)
     elif dtype == np.uint8:
         return debug_slice_u8(m, name)
+    elif dtype == np.int8:
+        return debug_slice_i8(m, name)
     elif dtype == np.uint64:
         return debug_slice_u64(m, name)
     else:
@@ -112,6 +116,8 @@ def query_vq_nth(db: "colMajorMatrix", *args):
         return query_vq_f32(db, *args)
     elif db.dtype == np.uint8:
         return query_vq_u8(db, *args)
+    elif db.dtype == np.int8:
+        return query_vq_i8(db, *args)
     else:
         raise TypeError("Unknown type!")
 
@@ -131,6 +137,8 @@ def query_vq_heap(db: "colMajorMatrix", *args):
         return vq_query_heap_f32(db, *args)
     elif db.dtype == np.uint8:
         return vq_query_heap_u8(db, *args)
+    elif db.dtype == np.int8:
+        return vq_query_heap_i8(db, *args)
     else:
         raise TypeError("Unknown type!")
 
@@ -150,6 +158,8 @@ def query_vq_heap_pyarray(db: "colMajorMatrix", *args):
         return vq_query_heap_pyarray_f32(db, *args)
     elif db.dtype == np.uint8:
         return vq_query_heap_pyarray_u8(db, *args)
+    elif db.dtype == np.int8:
+        return vq_query_heap_pyarray_i8(db, *args)
     else:
         raise TypeError("Unknown type!")
 
@@ -195,6 +205,8 @@ def ivf_index_tdb(
         return ivf_index_tdb_f32(*args)
     elif dtype == np.uint8:
         return ivf_index_tdb_u8(*args)
+    elif dtype == np.int8:
+        return ivf_index_tdb_i8(*args)
     else:
         raise TypeError("Unknown type!")
 
@@ -240,6 +252,8 @@ def ivf_index(
         return ivf_index_f32(*args)
     elif dtype == np.uint8:
         return ivf_index_u8(*args)
+    elif dtype == np.int8:
+        return ivf_index_i8(*args)
     else:
         raise TypeError("Unknown type!")
 
@@ -312,6 +326,11 @@ def ivf_query_ram(
             return nuv_query_heap_infinite_ram_reg_blocked_u8(*args)
         else:
             return qv_query_heap_infinite_ram_u8(*args)
+    elif dtype == np.int8:
+        if use_nuv_implementation:
+            return nuv_query_heap_infinite_ram_reg_blocked_i8(*args)
+        else:
+            return qv_query_heap_infinite_ram_i8(*args)
     else:
         raise TypeError("Unknown type!")
 
@@ -394,6 +413,11 @@ def ivf_query(
             return nuv_query_heap_finite_ram_reg_blocked_u8(*args)
         else:
             return qv_query_heap_finite_ram_u8(*args)
+    elif dtype == np.int8:
+        if use_nuv_implementation:
+            return nuv_query_heap_finite_ram_reg_blocked_i8(*args)
+        else:
+            return qv_query_heap_finite_ram_i8(*args)
     else:
         raise TypeError("Unknown type!")
 
@@ -403,6 +427,8 @@ def partition_ivf_index(centroids, query, nprobe=1, nthreads=0):
         return partition_ivf_index_f32(centroids, query, nprobe, nthreads)
     elif query.dtype == np.uint8:
         return partition_ivf_index_u8(centroids, query, nprobe, nthreads)
+    elif query.dtype == np.int8:
+        return partition_ivf_index_i8(centroids, query, nprobe, nthreads)
     else:
         raise TypeError("Unsupported type!")
 
@@ -442,6 +468,8 @@ def dist_qv(
         return dist_qv_f32(*args)
     elif dtype == np.uint8:
         return dist_qv_u8(*args)
+    elif dtype == np.int8:
+        return dist_qv_i8(*args)
     else:
         raise TypeError("Unsupported type!")
 
@@ -464,6 +492,8 @@ def array_to_matrix(array: np.ndarray):
         return pyarray_copyto_matrix_i32(array)
     elif array.dtype == np.uint64:
         return pyarray_copyto_matrix_u64(array)
+    elif array.dtype == np.int8:
+        return pyarray_copyto_matrix_i8(array)
     else:
         raise TypeError("Unsupported type!")
 
```
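
With the new `array_to_matrix` branch above, an `int8` NumPy array converts directly instead of falling through to `TypeError("Unsupported type!")`. A minimal check follows; the import path is an assumption.

```python
import numpy as np

from tiledb.vector_search.module import array_to_matrix  # assumed import path

vectors = np.random.randint(-128, 128, size=(128, 1000), dtype=np.int8)
m = array_to_matrix(vectors)  # now dispatches to pyarray_copyto_matrix_i8
assert m.dtype == np.int8     # the same dtype attribute the dispatch helpers check
```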

apis/python/test/common.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -351,3 +351,15 @@ def move_local_index_to_new_location(index_uri):
     shutil.copytree(index_uri, copied_index_uri)
     shutil.rmtree(index_uri)
     return copied_index_uri
+
+
+def quantize_embeddings_int8(
+    embeddings: np.ndarray,
+) -> np.ndarray:
+    """
+    Quantizes embeddings to a lower precision.
+    """
+    ranges = np.vstack((np.min(embeddings, axis=0), np.max(embeddings, axis=0)))
+    starts = ranges[0, :]
+    steps = (ranges[1, :] - ranges[0, :]) / 255
+    return ((embeddings - starts) / steps - 128).astype(np.int8)
```
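
The helper scales each dimension independently: with `steps = (max - min) / 255`, a value at the per-dimension minimum maps exactly to -128, and the maximum maps to 127 (or 126, after `float32` rounding and the truncating `int8` cast). A small sanity check, assuming it is run from `apis/python/test/` where the helper above is importable; the random data and seed are illustrative.

```python
import numpy as np

from common import quantize_embeddings_int8  # the helper defined above

rng = np.random.default_rng(0)
embeddings = rng.standard_normal((1000, 64)).astype(np.float32)
q = quantize_embeddings_int8(embeddings)

assert q.dtype == np.int8
assert q.min() == -128  # each dimension's minimum always lands on -128
assert q.max() >= 126   # the maximum lands on 127, or 126 after float32 rounding
```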

apis/python/test/test_ingestion.py

Lines changed: 48 additions & 0 deletions
```diff
@@ -337,6 +337,54 @@ def test_ingestion_numpy(tmp_path):
     assert accuracy(result, gt_i) > MINIMUM_ACCURACY
 
 
+def test_ingestion_numpy_i8(tmp_path):
+    source_uri = siftsmall_inputs_file
+    queries_uri = siftsmall_query_file
+    gt_uri = siftsmall_groundtruth_file
+    index_uri = os.path.join(tmp_path, "array")
+    k = 100
+    partitions = 100
+    nqueries = 100
+    nprobe = 20
+
+    input_vectors = quantize_embeddings_int8(load_fvecs(source_uri))
+
+    queries = quantize_embeddings_int8(load_fvecs(queries_uri)).astype(np.float32)
+    gt_i, gt_d = get_groundtruth_ivec(gt_uri, k=k, nqueries=nqueries)
+
+    for index_type, index_class in zip(INDEXES, INDEX_CLASSES):
+        # TODO(paris): Fix Vamana bug and re-enable:
+        # RuntimeError: IndexError: index 100 is out of bounds for axis 0 with size 100
+        if index_type == "VAMANA":
+            continue
+
+        index_uri = os.path.join(tmp_path, f"array_{index_type}")
+        index = ingest(
+            index_type=index_type,
+            index_uri=index_uri,
+            input_vectors=input_vectors,
+            partitions=partitions,
+        )
+        _, result = index.query(queries, k=k, nprobe=nprobe)
+        assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+        index_uri = move_local_index_to_new_location(index_uri)
+        index_ram = index_class(uri=index_uri)
+        _, result = index_ram.query(queries, k=k, nprobe=nprobe)
+        assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+        _, result = index_ram.query(
+            queries,
+            k=k,
+            nprobe=nprobe,
+            use_nuv_implementation=True,
+        )
+        assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+        _, result = index_ram.query(queries, k=k, nprobe=nprobe, mode=Mode.LOCAL)
+        assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+
 def test_ingestion_multiple_workers(tmp_path):
     source_uri = siftsmall_inputs_file
     queries_uri = siftsmall_query_file
```
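
Since the test takes the pytest `tmp_path` fixture, the new case can presumably be run on its own with `pytest apis/python/test/test_ingestion.py::test_ingestion_numpy_i8`; it needs the siftsmall input, query, and ground-truth files referenced above.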

src/include/detail/scoring/l2_distance.h

Lines changed: 11 additions & 6 deletions
```diff
@@ -135,7 +135,8 @@ inline float naive_sum_of_squares(const V& a, const W& b) {
  */
 template <feature_vector V>
   requires std::same_as<typename V::value_type, float> ||
-           std::same_as<typename V::value_type, uint8_t>
+           std::same_as<typename V::value_type, uint8_t> ||
+           std::same_as<typename V::value_type, int8_t>
 inline float unroll4_sum_of_squares(const V& a) {
   size_t size_a = size(a);
   size_t stop = 4 * (size_a / 4);
@@ -187,7 +188,8 @@ inline float unroll4_sum_of_squares(const V& a, const W& b) {
  */
 template <feature_vector V, feature_vector W>
   requires std::same_as<typename V::value_type, float> &&
-           std::same_as<typename W::value_type, uint8_t>
+           (std::same_as<typename W::value_type, uint8_t> ||
+            std::same_as<typename W::value_type, int8_t>)
 inline float unroll4_sum_of_squares(const V& a, const W& b) {
   size_t size_a = size(a);
   size_t stop = 4 * (size_a / 4);
@@ -212,8 +214,9 @@ inline float unroll4_sum_of_squares(const V& a, const W& b) {
  * Unrolled l2 distance between vector of uint8_t and vector of float
  */
 template <feature_vector V, feature_vector W>
-  requires std::same_as<typename V::value_type, uint8_t> &&
-           std::same_as<typename W::value_type, float>
+  requires(std::same_as<typename V::value_type, uint8_t> ||
+           std::same_as<typename V::value_type, int8_t>) &&
+          std::same_as<typename W::value_type, float>
 inline float unroll4_sum_of_squares(const V& a, const W& b) {
   size_t size_a = size(a);
   size_t stop = 4 * (size_a / 4);
@@ -238,8 +241,10 @@ inline float unroll4_sum_of_squares(const V& a, const W& b) {
  * Unrolled l2 distance between vector of uint8_t and vector of uint8_t
  */
 template <feature_vector V, feature_vector W>
-  requires std::same_as<typename V::value_type, uint8_t> &&
-           std::same_as<typename W::value_type, uint8_t>
+  requires(std::same_as<typename V::value_type, uint8_t> ||
+           std::same_as<typename V::value_type, int8_t>) &&
+          (std::same_as<typename W::value_type, uint8_t> ||
+           std::same_as<typename W::value_type, int8_t>)
 inline float unroll4_sum_of_squares(const V& a, const W& b) {
   size_t size_a = size(a);
   size_t stop = 4 * (size_a / 4);
```
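
The widened `requires` clauses admit `int8_t` operands in the unrolled kernels alongside `uint8_t` and `float`. As before, the kernels return `float`, so differences and the running sum are held in types wider than the 8-bit inputs. A NumPy sketch of why that matters; the function name and values are illustrative, not the library's API.

```python
import numpy as np


def sum_of_squares(a: np.ndarray, b: np.ndarray) -> float:
    # Widen to float32 before subtracting: in int8 arithmetic the difference
    # (-128) - 127 alone would wrap, and 255**2 = 65025 does not fit in 8 bits.
    diff = a.astype(np.float32) - b.astype(np.float32)
    return float(np.dot(diff, diff))


a = np.array([-128, 0, 127], dtype=np.int8)           # int8 database vector
b = np.array([127.0, 0.0, -128.0], dtype=np.float32)  # float32 query vector
print(sum_of_squares(a, b))  # 130050.0 = 255**2 + 0 + 255**2
```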
