@@ -161,13 +161,35 @@ void _get_centers(cuvsIvfPqIndex index, DLManagedTensor* output)
161161 cuvs::core::to_dlpack (strided_centers, output);
162162}
163163
164+ template <typename IdxT>
165+ void _get_centers_padded (cuvsIvfPqIndex index, DLManagedTensor* output)
166+ {
167+ auto index_ptr = reinterpret_cast <cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr );
168+ // Return the full padded centers [n_lists, dim_ext] as a contiguous array
169+ cuvs::core::to_dlpack (index_ptr->centers (), output);
170+ }
171+
164172template <typename IdxT>
165173void _get_pq_centers (cuvsIvfPqIndex index, DLManagedTensor* centers)
166174{
167175 auto index_ptr = reinterpret_cast <cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr );
168176 cuvs::core::to_dlpack (index_ptr->pq_centers (), centers);
169177}
170178
179+ template <typename IdxT>
180+ void _get_centers_rot (cuvsIvfPqIndex index, DLManagedTensor* centers_rot)
181+ {
182+ auto index_ptr = reinterpret_cast <cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr );
183+ cuvs::core::to_dlpack (index_ptr->centers_rot (), centers_rot);
184+ }
185+
186+ template <typename IdxT>
187+ void _get_rotation_matrix (cuvsIvfPqIndex index, DLManagedTensor* rotation_matrix)
188+ {
189+ auto index_ptr = reinterpret_cast <cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr );
190+ cuvs::core::to_dlpack (index_ptr->rotation_matrix (), rotation_matrix);
191+ }
192+
171193template <typename IdxT>
172194void _get_list_sizes (cuvsIvfPqIndex index, DLManagedTensor* list_sizes)
173195{
@@ -355,6 +377,12 @@ extern "C" cuvsError_t cuvsIvfPqExtend(cuvsResources_t res,
355377 return cuvs::core::translate_exceptions ([=] {
356378 auto vectors = new_vectors->dl_tensor ;
357379
380+ // Set the index dtype if not already set (e.g., for view-type indices built from precomputed data)
381+ if (index->dtype .code == 0 && index->dtype .bits == 0 ) {
382+ index->dtype .code = vectors.dtype .code ;
383+ index->dtype .bits = vectors.dtype .bits ;
384+ }
385+
358386 if (vectors.dtype .code == kDLFloat && vectors.dtype .bits == 32 ) {
359387 _extend<float , int64_t >(res, new_vectors, new_indices, *index);
360388 } else if (vectors.dtype .code == kDLFloat && vectors.dtype .bits == 16 ) {
@@ -422,12 +450,92 @@ extern "C" cuvsError_t cuvsIvfPqIndexGetCenters(cuvsIvfPqIndex_t index, DLManage
422450 return cuvs::core::translate_exceptions ([=] { _get_centers<int64_t >(*index, centers); });
423451}
424452
453+ extern " C" cuvsError_t cuvsIvfPqIndexGetCentersPadded (cuvsIvfPqIndex_t index,
454+ DLManagedTensor* centers)
455+ {
456+ return cuvs::core::translate_exceptions ([=] { _get_centers_padded<int64_t >(*index, centers); });
457+ }
458+
425459extern " C" cuvsError_t cuvsIvfPqIndexGetPqCenters (cuvsIvfPqIndex_t index,
426460 DLManagedTensor* pq_centers)
427461{
428462 return cuvs::core::translate_exceptions ([=] { _get_pq_centers<int64_t >(*index, pq_centers); });
429463}
430464
465+ extern " C" cuvsError_t cuvsIvfPqIndexGetCentersRot (cuvsIvfPqIndex_t index,
466+ DLManagedTensor* centers_rot)
467+ {
468+ return cuvs::core::translate_exceptions ([=] { _get_centers_rot<int64_t >(*index, centers_rot); });
469+ }
470+
471+ extern " C" cuvsError_t cuvsIvfPqIndexGetRotationMatrix (cuvsIvfPqIndex_t index,
472+ DLManagedTensor* rotation_matrix)
473+ {
474+ return cuvs::core::translate_exceptions (
475+ [=] { _get_rotation_matrix<int64_t >(*index, rotation_matrix); });
476+ }
477+
478+ extern " C" cuvsError_t cuvsIvfPqBuildPrecomputed (cuvsResources_t res,
479+ cuvsIvfPqIndexParams_t params,
480+ uint32_t dim,
481+ DLManagedTensor* pq_centers_tensor,
482+ DLManagedTensor* centers_tensor,
483+ DLManagedTensor* centers_rot_tensor,
484+ DLManagedTensor* rotation_matrix_tensor,
485+ cuvsIvfPqIndex_t index)
486+ {
487+ return cuvs::core::translate_exceptions ([=] {
488+ auto res_ptr = reinterpret_cast <raft::resources*>(res);
489+
490+ auto build_params = cuvs::neighbors::ivf_pq::index_params ();
491+ convert_c_index_params (*params, &build_params);
492+
493+ // Verify all tensors are on device
494+ RAFT_EXPECTS (cuvs::core::is_dlpack_device_compatible (pq_centers_tensor->dl_tensor ),
495+ " pq_centers should have device compatible memory" );
496+ RAFT_EXPECTS (cuvs::core::is_dlpack_device_compatible (centers_tensor->dl_tensor ),
497+ " centers should have device compatible memory" );
498+ RAFT_EXPECTS (cuvs::core::is_dlpack_device_compatible (centers_rot_tensor->dl_tensor ),
499+ " centers_rot should have device compatible memory" );
500+ RAFT_EXPECTS (cuvs::core::is_dlpack_device_compatible (rotation_matrix_tensor->dl_tensor ),
501+ " rotation_matrix should have device compatible memory" );
502+
503+ // Verify all tensors are float32
504+ auto & pq_centers_dl = pq_centers_tensor->dl_tensor ;
505+ auto & centers_dl = centers_tensor->dl_tensor ;
506+ auto & centers_rot_dl = centers_rot_tensor->dl_tensor ;
507+ auto & rotation_matrix_dl = rotation_matrix_tensor->dl_tensor ;
508+
509+ RAFT_EXPECTS (pq_centers_dl.dtype .code == kDLFloat && pq_centers_dl.dtype .bits == 32 ,
510+ " pq_centers must be float32" );
511+ RAFT_EXPECTS (centers_dl.dtype .code == kDLFloat && centers_dl.dtype .bits == 32 ,
512+ " centers must be float32" );
513+ RAFT_EXPECTS (centers_rot_dl.dtype .code == kDLFloat && centers_rot_dl.dtype .bits == 32 ,
514+ " centers_rot must be float32" );
515+ RAFT_EXPECTS (rotation_matrix_dl.dtype .code == kDLFloat && rotation_matrix_dl.dtype .bits == 32 ,
516+ " rotation_matrix must be float32" );
517+
518+ // Convert DLPack tensors to mdspan views
519+ using pq_centers_mdspan_type = raft::device_mdspan<const float , raft::extent_3d<uint32_t >, raft::row_major>;
520+ using matrix_mdspan_type = raft::device_matrix_view<const float , uint32_t , raft::row_major>;
521+
522+ auto pq_centers_mds = cuvs::core::from_dlpack<pq_centers_mdspan_type>(pq_centers_tensor);
523+ auto centers_mds = cuvs::core::from_dlpack<matrix_mdspan_type>(centers_tensor);
524+ auto centers_rot_mds = cuvs::core::from_dlpack<matrix_mdspan_type>(centers_rot_tensor);
525+ auto rotation_matrix_mds = cuvs::core::from_dlpack<matrix_mdspan_type>(rotation_matrix_tensor);
526+
527+ // Build the index
528+ auto * idx = new cuvs::neighbors::ivf_pq::index<int64_t >(
529+ cuvs::neighbors::ivf_pq::build (
530+ *res_ptr, build_params, dim, pq_centers_mds, centers_mds, centers_rot_mds, rotation_matrix_mds));
531+
532+ index->addr = reinterpret_cast <uintptr_t >(idx);
533+ // Leave dtype unset (0) - it will be set when extend() is called with actual data
534+ index->dtype .code = 0 ;
535+ index->dtype .bits = 0 ;
536+ });
537+ }
538+
431539extern " C" cuvsError_t cuvsIvfPqIndexGetListSizes (cuvsIvfPqIndex_t index,
432540 DLManagedTensor* list_sizes)
433541{
0 commit comments