Skip to content

Commit 648bfca

Browse files
authored
Merge pull request rapidsai#1737 from rapidsai/release/26.02
Forward-merge release/26.02 into main
2 parents 7f539ae + aecbaa9 commit 648bfca

File tree

15 files changed

+584
-2
lines changed

15 files changed

+584
-2
lines changed

c/include/cuvs/neighbors/ivf_pq.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,29 @@ cuvsError_t cuvsIvfPqExtend(cuvsResources_t res,
601601
/**
602602
* @}
603603
*/
604+
605+
/**
606+
* @defgroup ivf_pq_c_index_transform IVF-PQ index transform
607+
* @{
608+
*/
609+
/**
610+
* @brief Transform the input data by applying pq-encoding
611+
*
612+
* @param[in] res cuvsResources_t opaque C handle
613+
* @param[in] index IVF-PQ index
614+
* @param[in] input_dataset DLManagedTensor* vectors to transform
615+
* @param[out] output_labels DLManagedTensor* Vector of cluster labels for each vector in the input
616+
* @param[out] output_dataset DLManagedTensor* input vectors after pq-encoding
617+
* @return cuvsError_t
618+
*/
619+
cuvsError_t cuvsIvfPqTransform(cuvsResources_t res,
620+
cuvsIvfPqIndex_t index,
621+
DLManagedTensor* input_dataset,
622+
DLManagedTensor* output_labels,
623+
DLManagedTensor* output_dataset);
624+
/**
625+
* @}
626+
*/
604627
#ifdef __cplusplus
605628
}
606629
#endif

c/src/neighbors/ivf_pq.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,28 @@ void _get_list_indices(cuvsIvfPqIndex index,
230230
cuvs::core::to_dlpack(list.indices.view(), out_labels);
231231
}
232232
}
233+
234+
template <typename T, typename IdxT>
235+
void _transform(cuvsResources_t res,
236+
cuvsIvfPqIndex index,
237+
DLManagedTensor* input_dataset,
238+
DLManagedTensor* output_labels,
239+
DLManagedTensor* output_dataset)
240+
{
241+
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr);
242+
auto res_ptr = reinterpret_cast<raft::resources*>(res);
243+
244+
using input_mdspan_type = raft::device_matrix_view<const T, IdxT, raft::row_major>;
245+
using labels_mdspan_type = raft::device_vector_view<uint32_t, IdxT, raft::row_major>;
246+
using output_mdspan_type = raft::device_matrix_view<uint8_t, IdxT, raft::row_major>;
247+
248+
auto input_mds = cuvs::core::from_dlpack<input_mdspan_type>(input_dataset);
249+
auto labels_mds = cuvs::core::from_dlpack<labels_mdspan_type>(output_labels);
250+
auto output_mds = cuvs::core::from_dlpack<output_mdspan_type>(output_dataset);
251+
252+
cuvs::neighbors::ivf_pq::transform(*res_ptr, *index_ptr, input_mds, labels_mds, output_mds);
253+
}
254+
233255
} // namespace
234256

235257
extern "C" cuvsError_t cuvsIvfPqIndexCreate(cuvsIvfPqIndex_t* index)
@@ -569,3 +591,43 @@ extern "C" cuvsError_t cuvsIvfPqIndexGetListIndices(cuvsIvfPqIndex_t index,
569591
return cuvs::core::translate_exceptions(
570592
[=] { _get_list_indices<int64_t>(*index, label, out_labels); });
571593
}
594+
595+
extern "C" cuvsError_t cuvsIvfPqTransform(cuvsResources_t res,
596+
cuvsIvfPqIndex_t index,
597+
DLManagedTensor* input_dataset,
598+
DLManagedTensor* output_labels,
599+
DLManagedTensor* output_dataset) {
600+
return cuvs::core::translate_exceptions(
601+
[=] {
602+
// Verify all tensors are on device
603+
RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(input_dataset->dl_tensor),
604+
"input_dataset should have device compatible memory");
605+
RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(output_labels->dl_tensor),
606+
"output_labels should have device compatible memory");
607+
RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(output_dataset->dl_tensor),
608+
"output_dataset should have device compatible memory");
609+
610+
// Verify dtypes of inputs
611+
auto& output_labels_dl = output_labels->dl_tensor;
612+
RAFT_EXPECTS(output_labels_dl.dtype.code == kDLUInt && output_labels_dl.dtype.bits == 32,
613+
"output_labels must have a uint32 dtype ");
614+
auto& output_dataset_dl = output_dataset->dl_tensor;
615+
RAFT_EXPECTS(output_dataset_dl.dtype.code == kDLUInt && output_dataset_dl.dtype.bits == 8,
616+
"output_dataset must have a uint8 dtype");
617+
618+
auto & dataset = input_dataset->dl_tensor;
619+
if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
620+
_transform<float, int64_t>(res, *index, input_dataset, output_labels, output_dataset);
621+
} else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) {
622+
_transform<half, int64_t>(res, *index, input_dataset, output_labels, output_dataset);
623+
} else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) {
624+
_transform<int8_t, int64_t>(res, *index, input_dataset, output_labels, output_dataset);
625+
} else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) {
626+
_transform<uint8_t, int64_t>(res, *index, input_dataset, output_labels, output_dataset);
627+
} else {
628+
RAFT_FAIL("Unsupported input_dataset DLtensor dtype: %d and bits: %d",
629+
dataset.dtype.code,
630+
dataset.dtype.bits);
631+
}
632+
});
633+
}

cpp/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,10 @@ if(NOT BUILD_CPU_ONLY)
513513
src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu
514514
src/neighbors/ivf_pq/detail/ivf_pq_search_int8_t_int64_t.cu
515515
src/neighbors/ivf_pq/detail/ivf_pq_search_uint8_t_int64_t.cu
516+
src/neighbors/ivf_pq/detail/ivf_pq_transform_float_int64_t.cu
517+
src/neighbors/ivf_pq/detail/ivf_pq_transform_half_int64_t.cu
518+
src/neighbors/ivf_pq/detail/ivf_pq_transform_int8_t_int64_t.cu
519+
src/neighbors/ivf_pq/detail/ivf_pq_transform_uint8_t_int64_t.cu
516520
src/neighbors/knn_merge_parts.cu
517521
src/neighbors/nn_descent.cu
518522
src/neighbors/nn_descent_float.cu

cpp/include/cuvs/neighbors/ivf_pq.hpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1963,6 +1963,48 @@ void search(raft::resources const& handle,
19631963
* @}
19641964
*/
19651965

1966+
/**
1967+
* @defgroup ivf_pq_cpp_transform IVF-PQ index transform
1968+
* @{
1969+
*/
1970+
/**
1971+
* @brief Transform a dataset by applying pq-encoding to each vector
1972+
*
1973+
* @param[in] handle
1974+
* @param[in] index ivf-pq constructed index
1975+
* @param[in] dataset a device matrix view to a row-major matrix [n_rows, index.dim()]
1976+
* @param[out] output_labels a device vector view [n_rows] that will get populaterd with the
1977+
* cluster ids (labels) for each vector in the input dataset
1978+
* @param[out] output_dataset a device matrix view [n_rows, ceildiv(index.pq_dim() *
1979+
* index.pq_bits(), 8)]] that will get populated with the pq-encoded dataset
1980+
*/
1981+
void transform(raft::resources const& handle,
1982+
const cuvs::neighbors::ivf_pq::index<int64_t>& index,
1983+
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
1984+
raft::device_vector_view<uint32_t, int64_t> output_labels,
1985+
raft::device_matrix_view<uint8_t, int64_t> output_dataset);
1986+
/** @copydoc transform */
1987+
void transform(raft::resources const& handle,
1988+
const cuvs::neighbors::ivf_pq::index<int64_t>& index,
1989+
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset,
1990+
raft::device_vector_view<uint32_t, int64_t> output_labels,
1991+
raft::device_matrix_view<uint8_t, int64_t> output_dataset);
1992+
/** @copydoc transform */
1993+
void transform(raft::resources const& handle,
1994+
const cuvs::neighbors::ivf_pq::index<int64_t>& index,
1995+
raft::device_matrix_view<const int8_t, int64_t, raft::row_major> dataset,
1996+
raft::device_vector_view<uint32_t, int64_t> output_labels,
1997+
raft::device_matrix_view<uint8_t, int64_t> output_dataset);
1998+
/** @copydoc transform */
1999+
void transform(raft::resources const& handle,
2000+
const cuvs::neighbors::ivf_pq::index<int64_t>& index,
2001+
raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> dataset,
2002+
raft::device_vector_view<uint32_t, int64_t> output_labels,
2003+
raft::device_matrix_view<uint8_t, int64_t> output_dataset);
2004+
/**
2005+
* @}
2006+
*/
2007+
19662008
/**
19672009
* @defgroup ivf_pq_cpp_serialize IVF-PQ index serialize
19682010
* @{

cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33

44
import datetime
@@ -59,6 +59,23 @@
5959
}
6060
"""
6161

62+
transform_include_macro = """
63+
#include "../ivf_pq_transform.cuh"
64+
"""
65+
66+
transform_macro = """
67+
#define CUVS_INST_IVF_PQ_TRANSFORM(T, IdxT) \\
68+
void transform(raft::resources const& handle, \\
69+
const cuvs::neighbors::ivf_pq::index<IdxT>& index, \\
70+
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \\
71+
raft::device_vector_view<uint32_t, IdxT> output_labels, \\
72+
raft::device_matrix_view<uint8_t, IdxT, raft::row_major> output_dataset) \\
73+
{ \\
74+
cuvs::neighbors::ivf_pq::detail::transform( \\
75+
handle, index, dataset, output_labels, output_dataset); \\
76+
}
77+
"""
78+
6279
macros = dict(
6380
build_extend=dict(
6481
include=build_include_macro,
@@ -70,6 +87,11 @@
7087
definition=search_macro,
7188
name="CUVS_INST_IVF_PQ_SEARCH",
7289
),
90+
transform=dict(
91+
include=transform_include_macro,
92+
definition=transform_macro,
93+
name="CUVS_INST_IVF_PQ_TRANSFORM",
94+
),
7395
)
7496

7597
for type_path, (T, IdxT) in types.items():
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
/*
7+
* NOTE: this file is generated by generate_ivf_pq.py
8+
*
9+
* Make changes there and run in this directory:
10+
*
11+
* > python generate_ivf_pq.py
12+
*
13+
*/
14+
15+
#include <cuvs/neighbors/ivf_pq.hpp>
16+
17+
#include "../ivf_pq_transform.cuh"
18+
19+
namespace cuvs::neighbors::ivf_pq {
20+
21+
#define CUVS_INST_IVF_PQ_TRANSFORM(T, IdxT) \
22+
void transform(raft::resources const& handle, \
23+
const cuvs::neighbors::ivf_pq::index<IdxT>& index, \
24+
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \
25+
raft::device_vector_view<uint32_t, IdxT> output_labels, \
26+
raft::device_matrix_view<uint8_t, IdxT, raft::row_major> output_dataset) \
27+
{ \
28+
cuvs::neighbors::ivf_pq::detail::transform( \
29+
handle, index, dataset, output_labels, output_dataset); \
30+
}
31+
CUVS_INST_IVF_PQ_TRANSFORM(float, int64_t);
32+
33+
#undef CUVS_INST_IVF_PQ_TRANSFORM
34+
35+
} // namespace cuvs::neighbors::ivf_pq
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
/*
7+
* NOTE: this file is generated by generate_ivf_pq.py
8+
*
9+
* Make changes there and run in this directory:
10+
*
11+
* > python generate_ivf_pq.py
12+
*
13+
*/
14+
15+
#include <cuvs/neighbors/ivf_pq.hpp>
16+
17+
#include "../ivf_pq_transform.cuh"
18+
19+
namespace cuvs::neighbors::ivf_pq {
20+
21+
#define CUVS_INST_IVF_PQ_TRANSFORM(T, IdxT) \
22+
void transform(raft::resources const& handle, \
23+
const cuvs::neighbors::ivf_pq::index<IdxT>& index, \
24+
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \
25+
raft::device_vector_view<uint32_t, IdxT> output_labels, \
26+
raft::device_matrix_view<uint8_t, IdxT, raft::row_major> output_dataset) \
27+
{ \
28+
cuvs::neighbors::ivf_pq::detail::transform( \
29+
handle, index, dataset, output_labels, output_dataset); \
30+
}
31+
CUVS_INST_IVF_PQ_TRANSFORM(half, int64_t);
32+
33+
#undef CUVS_INST_IVF_PQ_TRANSFORM
34+
35+
} // namespace cuvs::neighbors::ivf_pq
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
/*
7+
* NOTE: this file is generated by generate_ivf_pq.py
8+
*
9+
* Make changes there and run in this directory:
10+
*
11+
* > python generate_ivf_pq.py
12+
*
13+
*/
14+
15+
#include <cuvs/neighbors/ivf_pq.hpp>
16+
17+
#include "../ivf_pq_transform.cuh"
18+
19+
namespace cuvs::neighbors::ivf_pq {
20+
21+
#define CUVS_INST_IVF_PQ_TRANSFORM(T, IdxT) \
22+
void transform(raft::resources const& handle, \
23+
const cuvs::neighbors::ivf_pq::index<IdxT>& index, \
24+
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \
25+
raft::device_vector_view<uint32_t, IdxT> output_labels, \
26+
raft::device_matrix_view<uint8_t, IdxT, raft::row_major> output_dataset) \
27+
{ \
28+
cuvs::neighbors::ivf_pq::detail::transform( \
29+
handle, index, dataset, output_labels, output_dataset); \
30+
}
31+
CUVS_INST_IVF_PQ_TRANSFORM(int8_t, int64_t);
32+
33+
#undef CUVS_INST_IVF_PQ_TRANSFORM
34+
35+
} // namespace cuvs::neighbors::ivf_pq
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
/*
7+
* NOTE: this file is generated by generate_ivf_pq.py
8+
*
9+
* Make changes there and run in this directory:
10+
*
11+
* > python generate_ivf_pq.py
12+
*
13+
*/
14+
15+
#include <cuvs/neighbors/ivf_pq.hpp>
16+
17+
#include "../ivf_pq_transform.cuh"
18+
19+
namespace cuvs::neighbors::ivf_pq {
20+
21+
#define CUVS_INST_IVF_PQ_TRANSFORM(T, IdxT) \
22+
void transform(raft::resources const& handle, \
23+
const cuvs::neighbors::ivf_pq::index<IdxT>& index, \
24+
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \
25+
raft::device_vector_view<uint32_t, IdxT> output_labels, \
26+
raft::device_matrix_view<uint8_t, IdxT, raft::row_major> output_dataset) \
27+
{ \
28+
cuvs::neighbors::ivf_pq::detail::transform( \
29+
handle, index, dataset, output_labels, output_dataset); \
30+
}
31+
CUVS_INST_IVF_PQ_TRANSFORM(uint8_t, int64_t);
32+
33+
#undef CUVS_INST_IVF_PQ_TRANSFORM
34+
35+
} // namespace cuvs::neighbors::ivf_pq

cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1122,7 +1122,7 @@ void extend(raft::resources const& handle,
11221122
auto centers_view = raft::make_device_matrix_view<const float, internal_extents_t>(
11231123
cluster_centers.data(), n_clusters, index->dim());
11241124
cuvs::cluster::kmeans::balanced_params kmeans_params;
1125-
kmeans_params.metric = static_cast<cuvs::distance::DistanceType>((int)index->metric());
1125+
kmeans_params.metric = index->metric();
11261126
cuvs::cluster::kmeans::predict(
11271127
handle, kmeans_params, batch_data_view, centers_view, batch_labels_view);
11281128
vec_batches.prefetch_next_batch();

0 commit comments

Comments
 (0)