Skip to content

Commit 19b050a

Browse files
authored
[FEA] C + Python API for IVF-PQ Build Factories with Precomputed Centroids (rapidsai#1664)
To be merged after rapidsai#1483. Authors: Tarang Jain (https://github.com/tarang-jain). Approvers: Ben Frederickson (https://github.com/benfred). URL: rapidsai#1664
1 parent 234a1af commit 19b050a

File tree

6 files changed

+471
-1
lines changed

6 files changed

+471
-1
lines changed

c/include/cuvs/neighbors/ivf_pq.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,19 @@ cuvsError_t cuvsIvfPqIndexGetPqLen(cuvsIvfPqIndex_t index, int64_t* pq_len);
278278
*/
279279
cuvsError_t cuvsIvfPqIndexGetCenters(cuvsIvfPqIndex_t index, DLManagedTensor* centers);
280280

281+
/**
282+
* @brief Get the padded cluster centers [n_lists, dim_ext]
283+
* where dim_ext = round_up(dim + 1, 8)
284+
*
285+
* This returns the full padded centers as a contiguous array, suitable for
286+
* use with cuvsIvfPqBuildPrecomputed.
287+
*
288+
* @param[in] index cuvsIvfPqIndex_t Built Ivf-Pq index
289+
* @param[out] centers Output tensor that will be populated with a non-owning view of the data
290+
* @return cuvsError_t
291+
*/
292+
cuvsError_t cuvsIvfPqIndexGetCentersPadded(cuvsIvfPqIndex_t index, DLManagedTensor* centers);
293+
281294
/**
282295
* @brief Get the PQ cluster centers
283296
*
@@ -290,6 +303,28 @@ cuvsError_t cuvsIvfPqIndexGetCenters(cuvsIvfPqIndex_t index, DLManagedTensor* ce
290303
*/
291304
cuvsError_t cuvsIvfPqIndexGetPqCenters(cuvsIvfPqIndex_t index, DLManagedTensor* pq_centers);
292305

306+
/**
307+
* @brief Get the rotated cluster centers [n_lists, rot_dim]
308+
* where rot_dim = pq_len * pq_dim
309+
*
310+
* @param[in] index cuvsIvfPqIndex_t Built Ivf-Pq index
311+
* @param[out] centers_rot Output tensor that will be populated with a non-owning view of the data
312+
* @return cuvsError_t
313+
*/
314+
cuvsError_t cuvsIvfPqIndexGetCentersRot(cuvsIvfPqIndex_t index, DLManagedTensor* centers_rot);
315+
316+
/**
317+
* @brief Get the rotation matrix [rot_dim, dim]
318+
* Transform matrix (original space -> rotated padded space)
319+
*
320+
* @param[in] index cuvsIvfPqIndex_t Built Ivf-Pq index
321+
* @param[out] rotation_matrix Output tensor that will be populated with a non-owning view of the
322+
* data
323+
* @return cuvsError_t
324+
*/
325+
cuvsError_t cuvsIvfPqIndexGetRotationMatrix(cuvsIvfPqIndex_t index,
326+
DLManagedTensor* rotation_matrix);
327+
293328
/**
294329
* @brief Get the sizes of each list
295330
*
@@ -389,6 +424,44 @@ cuvsError_t cuvsIvfPqBuild(cuvsResources_t res,
389424
cuvsIvfPqIndexParams_t params,
390425
DLManagedTensor* dataset,
391426
cuvsIvfPqIndex_t index);
427+
428+
/**
429+
* @brief Build a view-type IVF-PQ index from precomputed centroids and a PQ codebook that reside in device memory.
430+
*
431+
* This function creates a non-owning index that stores a reference to the provided device data.
432+
* All parameters must be provided with correct extents. The caller is responsible for ensuring
433+
* the lifetime of the input data exceeds the lifetime of the returned index.
434+
*
435+
* The index parameters (`params`, referred to as index_params below) must be consistent with the provided matrices. Specifically:
436+
* - index_params.codebook_kind determines the expected shape of pq_centers
437+
* - index_params.metric will be stored in the index
438+
* - index_params.conservative_memory_allocation will be stored in the index
439+
* The function will verify consistency between index_params, dim, and the matrix extents.
440+
*
441+
* @param[in] res cuvsResources_t opaque C handle
442+
* @param[in] params cuvsIvfPqIndexParams_t used to configure the index (must be consistent with
443+
* matrices)
444+
* @param[in] dim dimensionality of the input data
445+
* @param[in] pq_centers PQ codebook on device memory with required shape:
446+
* - codebook_kind PER_SUBSPACE: [pq_dim, pq_len, pq_book_size]
447+
* - codebook_kind PER_CLUSTER: [n_lists, pq_len, pq_book_size]
448+
* @param[in] centers Cluster centers in the original space [n_lists, dim_ext]
449+
* where dim_ext = round_up(dim + 1, 8)
450+
* @param[in] centers_rot Rotated cluster centers [n_lists, rot_dim]
451+
* where rot_dim = pq_len * pq_dim
452+
* @param[in] rotation_matrix Transform matrix (original space -> rotated padded space) [rot_dim,
453+
* dim]
454+
* @param[out] index cuvsIvfPqIndex_t Newly built view-type IVF-PQ index
455+
* @return cuvsError_t
456+
*/
457+
cuvsError_t cuvsIvfPqBuildPrecomputed(cuvsResources_t res,
458+
cuvsIvfPqIndexParams_t params,
459+
uint32_t dim,
460+
DLManagedTensor* pq_centers,
461+
DLManagedTensor* centers,
462+
DLManagedTensor* centers_rot,
463+
DLManagedTensor* rotation_matrix,
464+
cuvsIvfPqIndex_t index);
392465
/**
393466
* @}
394467
*/

c/src/neighbors/ivf_pq.cpp

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,35 @@ void _get_centers(cuvsIvfPqIndex index, DLManagedTensor* output)
161161
cuvs::core::to_dlpack(strided_centers, output);
162162
}
163163

164+
template <typename IdxT>
165+
void _get_centers_padded(cuvsIvfPqIndex index, DLManagedTensor* output)
166+
{
167+
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr);
168+
// Return the full padded centers [n_lists, dim_ext] as a contiguous array
169+
cuvs::core::to_dlpack(index_ptr->centers(), output);
170+
}
171+
164172
template <typename IdxT>
165173
void _get_pq_centers(cuvsIvfPqIndex index, DLManagedTensor* centers)
166174
{
167175
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr);
168176
cuvs::core::to_dlpack(index_ptr->pq_centers(), centers);
169177
}
170178

179+
template <typename IdxT>
180+
void _get_centers_rot(cuvsIvfPqIndex index, DLManagedTensor* centers_rot)
181+
{
182+
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr);
183+
cuvs::core::to_dlpack(index_ptr->centers_rot(), centers_rot);
184+
}
185+
186+
template <typename IdxT>
187+
void _get_rotation_matrix(cuvsIvfPqIndex index, DLManagedTensor* rotation_matrix)
188+
{
189+
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr);
190+
cuvs::core::to_dlpack(index_ptr->rotation_matrix(), rotation_matrix);
191+
}
192+
171193
template <typename IdxT>
172194
void _get_list_sizes(cuvsIvfPqIndex index, DLManagedTensor* list_sizes)
173195
{
@@ -355,6 +377,12 @@ extern "C" cuvsError_t cuvsIvfPqExtend(cuvsResources_t res,
355377
return cuvs::core::translate_exceptions([=] {
356378
auto vectors = new_vectors->dl_tensor;
357379

380+
// Set the index dtype if not already set (e.g., for view-type indices built from precomputed data)
381+
if (index->dtype.code == 0 && index->dtype.bits == 0) {
382+
index->dtype.code = vectors.dtype.code;
383+
index->dtype.bits = vectors.dtype.bits;
384+
}
385+
358386
if (vectors.dtype.code == kDLFloat && vectors.dtype.bits == 32) {
359387
_extend<float, int64_t>(res, new_vectors, new_indices, *index);
360388
} else if (vectors.dtype.code == kDLFloat && vectors.dtype.bits == 16) {
@@ -422,12 +450,92 @@ extern "C" cuvsError_t cuvsIvfPqIndexGetCenters(cuvsIvfPqIndex_t index, DLManage
422450
return cuvs::core::translate_exceptions([=] { _get_centers<int64_t>(*index, centers); });
423451
}
424452

453+
extern "C" cuvsError_t cuvsIvfPqIndexGetCentersPadded(cuvsIvfPqIndex_t index,
454+
DLManagedTensor* centers)
455+
{
456+
return cuvs::core::translate_exceptions([=] { _get_centers_padded<int64_t>(*index, centers); });
457+
}
458+
425459
extern "C" cuvsError_t cuvsIvfPqIndexGetPqCenters(cuvsIvfPqIndex_t index,
426460
DLManagedTensor* pq_centers)
427461
{
428462
return cuvs::core::translate_exceptions([=] { _get_pq_centers<int64_t>(*index, pq_centers); });
429463
}
430464

465+
extern "C" cuvsError_t cuvsIvfPqIndexGetCentersRot(cuvsIvfPqIndex_t index,
466+
DLManagedTensor* centers_rot)
467+
{
468+
return cuvs::core::translate_exceptions([=] { _get_centers_rot<int64_t>(*index, centers_rot); });
469+
}
470+
471+
extern "C" cuvsError_t cuvsIvfPqIndexGetRotationMatrix(cuvsIvfPqIndex_t index,
472+
DLManagedTensor* rotation_matrix)
473+
{
474+
return cuvs::core::translate_exceptions(
475+
[=] { _get_rotation_matrix<int64_t>(*index, rotation_matrix); });
476+
}
477+
478+
extern "C" cuvsError_t cuvsIvfPqBuildPrecomputed(cuvsResources_t res,
479+
cuvsIvfPqIndexParams_t params,
480+
uint32_t dim,
481+
DLManagedTensor* pq_centers_tensor,
482+
DLManagedTensor* centers_tensor,
483+
DLManagedTensor* centers_rot_tensor,
484+
DLManagedTensor* rotation_matrix_tensor,
485+
cuvsIvfPqIndex_t index)
486+
{
487+
return cuvs::core::translate_exceptions([=] {
488+
auto res_ptr = reinterpret_cast<raft::resources*>(res);
489+
490+
auto build_params = cuvs::neighbors::ivf_pq::index_params();
491+
convert_c_index_params(*params, &build_params);
492+
493+
// Verify all tensors are on device
494+
RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(pq_centers_tensor->dl_tensor),
495+
"pq_centers should have device compatible memory");
496+
RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(centers_tensor->dl_tensor),
497+
"centers should have device compatible memory");
498+
RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(centers_rot_tensor->dl_tensor),
499+
"centers_rot should have device compatible memory");
500+
RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(rotation_matrix_tensor->dl_tensor),
501+
"rotation_matrix should have device compatible memory");
502+
503+
// Verify all tensors are float32
504+
auto& pq_centers_dl = pq_centers_tensor->dl_tensor;
505+
auto& centers_dl = centers_tensor->dl_tensor;
506+
auto& centers_rot_dl = centers_rot_tensor->dl_tensor;
507+
auto& rotation_matrix_dl = rotation_matrix_tensor->dl_tensor;
508+
509+
RAFT_EXPECTS(pq_centers_dl.dtype.code == kDLFloat && pq_centers_dl.dtype.bits == 32,
510+
"pq_centers must be float32");
511+
RAFT_EXPECTS(centers_dl.dtype.code == kDLFloat && centers_dl.dtype.bits == 32,
512+
"centers must be float32");
513+
RAFT_EXPECTS(centers_rot_dl.dtype.code == kDLFloat && centers_rot_dl.dtype.bits == 32,
514+
"centers_rot must be float32");
515+
RAFT_EXPECTS(rotation_matrix_dl.dtype.code == kDLFloat && rotation_matrix_dl.dtype.bits == 32,
516+
"rotation_matrix must be float32");
517+
518+
// Convert DLPack tensors to mdspan views
519+
using pq_centers_mdspan_type = raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major>;
520+
using matrix_mdspan_type = raft::device_matrix_view<const float, uint32_t, raft::row_major>;
521+
522+
auto pq_centers_mds = cuvs::core::from_dlpack<pq_centers_mdspan_type>(pq_centers_tensor);
523+
auto centers_mds = cuvs::core::from_dlpack<matrix_mdspan_type>(centers_tensor);
524+
auto centers_rot_mds = cuvs::core::from_dlpack<matrix_mdspan_type>(centers_rot_tensor);
525+
auto rotation_matrix_mds = cuvs::core::from_dlpack<matrix_mdspan_type>(rotation_matrix_tensor);
526+
527+
// Build the index
528+
auto* idx = new cuvs::neighbors::ivf_pq::index<int64_t>(
529+
cuvs::neighbors::ivf_pq::build(
530+
*res_ptr, build_params, dim, pq_centers_mds, centers_mds, centers_rot_mds, rotation_matrix_mds));
531+
532+
index->addr = reinterpret_cast<uintptr_t>(idx);
533+
// Leave dtype unset (code == 0 and bits == 0) - cuvsIvfPqExtend() fills it in from the dtype of the vectors it is given
534+
index->dtype.code = 0;
535+
index->dtype.bits = 0;
536+
});
537+
}
538+
431539
extern "C" cuvsError_t cuvsIvfPqIndexGetListSizes(cuvsIvfPqIndex_t index,
432540
DLManagedTensor* list_sizes)
433541
{

python/cuvs/cuvs/neighbors/ivf_pq/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33

44

@@ -7,6 +7,7 @@
77
IndexParams,
88
SearchParams,
99
build,
10+
build_precomputed,
1011
extend,
1112
load,
1213
save,
@@ -18,6 +19,7 @@
1819
"IndexParams",
1920
"SearchParams",
2021
"build",
22+
"build_precomputed",
2123
"extend",
2224
"load",
2325
"save",

python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pxd

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,21 @@ cdef extern from "cuvs/neighbors/ivf_pq.h" nogil:
9292
cuvsError_t cuvsIvfPqIndexGetCenters(cuvsIvfPqIndex_t index,
9393
DLManagedTensor * centers)
9494

95+
cuvsError_t cuvsIvfPqIndexGetCentersPadded(cuvsIvfPqIndex_t index,
96+
DLManagedTensor * centers)
97+
9598
cuvsError_t cuvsIvfPqIndexGetListSizes(cuvsIvfPqIndex_t index,
9699
DLManagedTensor * list_sizes)
97100

98101
cuvsError_t cuvsIvfPqIndexGetPqCenters(cuvsIvfPqIndex_t index,
99102
DLManagedTensor * centers)
100103

104+
cuvsError_t cuvsIvfPqIndexGetCentersRot(cuvsIvfPqIndex_t index,
105+
DLManagedTensor * centers_rot)
106+
107+
cuvsError_t cuvsIvfPqIndexGetRotationMatrix(cuvsIvfPqIndex_t index,
108+
DLManagedTensor * rotation_matrix)
109+
101110
cuvsError_t cuvsIvfPqIndexUnpackContiguousListData(cuvsResources_t res,
102111
cuvsIvfPqIndex_t index,
103112
DLManagedTensor* out,
@@ -113,6 +122,15 @@ cdef extern from "cuvs/neighbors/ivf_pq.h" nogil:
113122
DLManagedTensor* dataset,
114123
cuvsIvfPqIndex_t index)
115124

125+
cuvsError_t cuvsIvfPqBuildPrecomputed(cuvsResources_t res,
126+
cuvsIvfPqIndexParams_t params,
127+
uint32_t dim,
128+
DLManagedTensor* pq_centers,
129+
DLManagedTensor* centers,
130+
DLManagedTensor* centers_rot,
131+
DLManagedTensor* rotation_matrix,
132+
cuvsIvfPqIndex_t index)
133+
116134
cuvsError_t cuvsIvfPqSearch(cuvsResources_t res,
117135
cuvsIvfPqSearchParams* params,
118136
cuvsIvfPqIndex_t index,

0 commit comments

Comments
 (0)