Skip to content

Commit e1d127c

Browse files
authored
[FEA] IVF-PQ Build Factories for Precomputed Centroids and Codebooks (rapidsai#1483)
IVF-PQ Index Build API Enhancements and PIMPL Refactoring

## Summary

This PR adds new build APIs for IVF-PQ indices using precomputed centroids and implements a complete PIMPL refactoring with owning/view semantics.

## Key Changes

1. New Build APIs for Precomputed Centroids
   - Added `cuvs::neighbors::ivf_pq::build()` overloads that accept precomputed cluster centroids, PQ codebooks, and rotation matrices.
   - Enables building indices from pre-trained models without re-training.
   - Supports both device and host input data with automatic memory transfer.

2. PIMPL Refactoring with Owning/View Semantics
   - `owning_impl`: Owns centroid and codebook data (traditional behavior).
   - `view_impl`: References external centroid data without copying.
   - Maintains identical search behavior with zero data copying.

3. Enhanced Helper Functions
   - Removed mutator functions that directly modify the state of the index. Instead, we have helpers for the user to fetch and own the transformed data.
   - View indices avoid copying large centroid arrays.

Backward Compatibility: All existing APIs work unchanged.

[Updates] (as of 12/23/2025): The non-const getters that allow direct modification of the state of the index (`pq_centers()`, `centers()`, `centers_rot()` and `rotation_matrix()`) have been removed from the interface and the user-facing class.

Authors:
- Tarang Jain (https://github.com/tarang-jain)
- Corey J. Nolet (https://github.com/cjnolet)
- Vyas Ramasubramani (https://github.com/vyasr)
- Lorenzo Dematté (https://github.com/ldematte)
- Divye Gala (https://github.com/divyegala)

Approvers:
- Kyle Edwards (https://github.com/KyleFromNVIDIA)
- Micka (https://github.com/lowener)

URL: rapidsai#1483
1 parent 1c67635 commit e1d127c

File tree

13 files changed

+1765
-461
lines changed

13 files changed

+1765
-461
lines changed

cpp/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# =============================================================================
22
# cmake-format: off
3-
# SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
3+
# SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION.
44
# SPDX-License-Identifier: Apache-2.0
55
# cmake-format: on
66
cmake_minimum_required(VERSION 3.30.4 FATAL_ERROR)
@@ -492,6 +492,7 @@ if(NOT BUILD_CPU_ONLY)
492492
src/neighbors/ivf_pq/detail/ivf_pq_build_extend_half_int64_t.cu
493493
src/neighbors/ivf_pq/detail/ivf_pq_build_extend_int8_t_int64_t.cu
494494
src/neighbors/ivf_pq/detail/ivf_pq_build_extend_uint8_t_int64_t.cu
495+
src/neighbors/ivf_pq/detail/ivf_pq_build_precomputed_int64_t.cu
495496
src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_false.cu
496497
src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_true.cu
497498
src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_half.cu
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
diff --git a/faiss/gpu/impl/CuvsIVFPQ.cu b/faiss/gpu/impl/CuvsIVFPQ.cu
2+
index 1e2fef225..35b388147 100644
3+
--- a/faiss/gpu/impl/CuvsIVFPQ.cu
4+
+++ b/faiss/gpu/impl/CuvsIVFPQ.cu
5+
@@ -129,8 +129,14 @@ void CuvsIVFPQ::updateQuantizer(Index* quantizer) {
6+
7+
cuvs::neighbors::ivf_pq::helpers::reset_index(
8+
raft_handle, cuvs_index.get());
9+
+ auto mutable_rotation_matrix_view =
10+
+ raft::make_device_matrix_view<float, uint32_t>(
11+
+ const_cast<float*>(
12+
+ cuvs_index->rotation_matrix().data_handle()),
13+
+ cuvs_index->rotation_matrix().extent(0),
14+
+ cuvs_index->rotation_matrix().extent(1));
15+
cuvs::neighbors::ivf_pq::helpers::make_rotation_matrix(
16+
- raft_handle, cuvs_index.get(), false);
17+
+ raft_handle, mutable_rotation_matrix_view, false);
18+
19+
// If the index instance is a GpuIndexFlat, then we can use direct access to
20+
// the centroids within.
21+
@@ -149,22 +155,60 @@ void CuvsIVFPQ::updateQuantizer(Index* quantizer) {
22+
// as float32 and store locally
23+
gpuData->reconstruct(0, gpuData->getSize(), centroids);
24+
25+
- cuvs::neighbors::ivf_pq::helpers::set_centers(
26+
- raft_handle,
27+
- cuvs_index.get(),
28+
+ auto mutable_centers_view =
29+
+ raft::make_device_matrix_view<float, uint32_t>(
30+
+ const_cast<float*>(
31+
+ cuvs_index->centers().data_handle()),
32+
+ numLists_,
33+
+ cuvs_index->centers().extent(1));
34+
+ auto mutable_centers_rot_view =
35+
raft::make_device_matrix_view<float, uint32_t>(
36+
- centroids.data(), numLists_, dim_));
37+
+ const_cast<float*>(
38+
+ cuvs_index->centers_rot().data_handle()),
39+
+ cuvs_index->centers_rot().extent(0),
40+
+ cuvs_index->centers_rot().extent(1));
41+
+
42+
+ cuvs::neighbors::ivf_pq::helpers::pad_centers_with_norms(
43+
+ raft_handle,
44+
+ raft::make_const_mdspan(
45+
+ raft::make_device_matrix_view<float, uint32_t>(
46+
+ centroids.data(), numLists_, dim_)),
47+
+ mutable_centers_view);
48+
+ cuvs::neighbors::ivf_pq::helpers::rotate_padded_centers(
49+
+ raft_handle,
50+
+ cuvs_index->centers(),
51+
+ cuvs_index->rotation_matrix(),
52+
+ mutable_centers_rot_view);
53+
} else {
54+
/// No reconstruct needed since the centers are already in float32
55+
// The FlatIndex keeps its data in float32, so we can merely
56+
// reference it
57+
auto centroids = gpuData->getVectorsFloat32Ref();
58+
59+
- cuvs::neighbors::ivf_pq::helpers::set_centers(
60+
- raft_handle,
61+
- cuvs_index.get(),
62+
+ auto mutable_centers_view =
63+
+ raft::make_device_matrix_view<float, uint32_t>(
64+
+ const_cast<float*>(
65+
+ cuvs_index->centers().data_handle()),
66+
+ numLists_,
67+
+ cuvs_index->centers().extent(1));
68+
+ auto mutable_centers_rot_view =
69+
raft::make_device_matrix_view<float, uint32_t>(
70+
- centroids.data(), numLists_, dim_));
71+
+ const_cast<float*>(
72+
+ cuvs_index->centers_rot().data_handle()),
73+
+ cuvs_index->centers_rot().extent(0),
74+
+ cuvs_index->centers_rot().extent(1));
75+
+
76+
+ cuvs::neighbors::ivf_pq::helpers::pad_centers_with_norms(
77+
+ raft_handle,
78+
+ raft::make_const_mdspan(
79+
+ raft::make_device_matrix_view<float, uint32_t>(
80+
+ centroids.data(), numLists_, dim_)),
81+
+ mutable_centers_view);
82+
+ cuvs::neighbors::ivf_pq::helpers::rotate_padded_centers(
83+
+ raft_handle,
84+
+ cuvs_index->centers(),
85+
+ cuvs_index->rotation_matrix(),
86+
+ mutable_centers_rot_view);
87+
}
88+
} else {
89+
DeviceTensor<float, 2, true> centroids(
90+
@@ -180,11 +224,30 @@ void CuvsIVFPQ::updateQuantizer(Index* quantizer) {
91+
92+
centroids.copyFrom(vecs, stream);
93+
94+
- cuvs::neighbors::ivf_pq::helpers::set_centers(
95+
- raft_handle,
96+
- cuvs_index.get(),
97+
+ // Create mutable views for output parameters
98+
+ auto mutable_centers_view =
99+
+ raft::make_device_matrix_view<float, uint32_t>(
100+
+ const_cast<float*>(cuvs_index->centers().data_handle()),
101+
+ numLists_,
102+
+ cuvs_index->centers().extent(1));
103+
+ auto mutable_centers_rot_view =
104+
raft::make_device_matrix_view<float, uint32_t>(
105+
- centroids.data(), numLists_, dim_));
106+
+ const_cast<float*>(
107+
+ cuvs_index->centers_rot().data_handle()),
108+
+ cuvs_index->centers_rot().extent(0),
109+
+ cuvs_index->centers_rot().extent(1));
110+
+
111+
+ cuvs::neighbors::ivf_pq::helpers::pad_centers_with_norms(
112+
+ raft_handle,
113+
+ raft::make_const_mdspan(
114+
+ raft::make_device_matrix_view<float, uint32_t>(
115+
+ centroids.data(), numLists_, dim_)),
116+
+ mutable_centers_view);
117+
+ cuvs::neighbors::ivf_pq::helpers::rotate_padded_centers(
118+
+ raft_handle,
119+
+ cuvs_index->centers(),
120+
+ cuvs_index->rotation_matrix(),
121+
+ mutable_centers_rot_view);
122+
}
123+
124+
setPQCentroids_();
125+
@@ -520,7 +583,7 @@ void CuvsIVFPQ::setPQCentroids_() {
126+
auto stream = resources_->getDefaultStreamCurrentDevice();
127+
128+
raft::copy(
129+
- cuvs_index->pq_centers().data_handle(),
130+
+ const_cast<float*>(cuvs_index->pq_centers().data_handle()),
131+
pqCentroidsInnermostCode_.data(),
132+
pqCentroidsInnermostCode_.numElements(),
133+
stream);

cpp/cmake/patches/faiss_override.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
"file" : "${current_json_dir}/faiss-1.13-cuvs-25.12.diff",
1010
"issue" : "Multiple fixes for cuVS and RMM compatibility",
1111
"fixed_in" : ""
12+
},
13+
{
14+
"file" : "${current_json_dir}/faiss-1.13-cuvs-26.02.diff",
15+
"issue" : "Multiple fixes for cuVS and RMM compatibility",
16+
"fixed_in" : ""
1217
}
1318
]
1419
}

0 commit comments

Comments
 (0)