Skip to content

Commit 94795b0

Browse files
authored
Use CCCL's mdspan implementation (rapidsai#1605)
Based on rapidsai/raft#2836 Authors: - Divye Gala (https://github.com/divyegala) Approvers: - Bradley Dice (https://github.com/bdice) - Dante Gama Dessavre (https://github.com/dantegd) URL: rapidsai#1605
1 parent ba67db1 commit 94795b0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+331
-310
lines changed

c/src/core/detail/interop.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ inline MdspanType from_dlpack(DLManagedTensor* managed_tensor)
129129
"ndim mismatch between return mdspan and DLTensor");
130130

131131
// auto exts = typename MdspanType::extents_type{tensor.shape};
132-
std::array<int64_t, MdspanType::extents_type::rank()> shape{};
132+
cuda::std::array<int64_t, MdspanType::extents_type::rank()> shape{};
133133
for (int64_t i = 0; i < tensor.ndim; ++i) {
134134
shape[i] = tensor.shape[i];
135135
}

cpp/cmake/thirdparty/get_raft.cmake

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,18 @@ function(find_and_configure_raft)
1414
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}"
1515
"${multiValueArgs}" ${ARGN} )
1616

17+
# Set BUILD_SHARED_LIBS whenever building static dependencies
18+
if(PKG_BUILD_STATIC_DEPS)
19+
set(BUILD_SHARED_LIBS OFF)
20+
endif()
21+
22+
# Determine whether to clone raft locally
1723
if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "${rapids-cmake-checkout-tag}")
1824
message(STATUS "cuVS: RAFT pinned tag found: ${PKG_PINNED_TAG}. Cloning raft locally.")
1925
set(CPM_DOWNLOAD_raft ON)
2026
elseif(PKG_BUILD_STATIC_DEPS AND (NOT CPM_raft_SOURCE))
2127
message(STATUS "cuVS: Cloning raft locally to build static libraries.")
2228
set(CPM_DOWNLOAD_raft ON)
23-
set(BUILD_SHARED_LIBS OFF)
2429
endif()
2530

2631
set(RAFT_COMPONENTS "")

cpp/include/cuvs/cluster/kmeans.hpp

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ void fit(raft::resources const& handle,
181181
raft::device_matrix_view<const float, int> X,
182182
std::optional<raft::device_vector_view<const float, int>> sample_weight,
183183
raft::device_matrix_view<float, int> centroids,
184-
raft::host_scalar_view<float, int> inertia,
185-
raft::host_scalar_view<int, int> n_iter);
184+
raft::host_scalar_view<float> inertia,
185+
raft::host_scalar_view<int> n_iter);
186186

187187
/**
188188
* @brief Find clusters with k-means algorithm.
@@ -232,8 +232,8 @@ void fit(raft::resources const& handle,
232232
raft::device_matrix_view<const float, int64_t> X,
233233
std::optional<raft::device_vector_view<const float, int64_t>> sample_weight,
234234
raft::device_matrix_view<float, int64_t> centroids,
235-
raft::host_scalar_view<float, int64_t> inertia,
236-
raft::host_scalar_view<int64_t, int64_t> n_iter);
235+
raft::host_scalar_view<float> inertia,
236+
raft::host_scalar_view<int64_t> n_iter);
237237

238238
/**
239239
* @brief Find clusters with k-means algorithm.
@@ -282,8 +282,8 @@ void fit(raft::resources const& handle,
282282
raft::device_matrix_view<const double, int> X,
283283
std::optional<raft::device_vector_view<const double, int>> sample_weight,
284284
raft::device_matrix_view<double, int> centroids,
285-
raft::host_scalar_view<double, int> inertia,
286-
raft::host_scalar_view<int, int> n_iter);
285+
raft::host_scalar_view<double> inertia,
286+
raft::host_scalar_view<int> n_iter);
287287

288288
/**
289289
* @brief Find clusters with k-means algorithm.
@@ -333,8 +333,8 @@ void fit(raft::resources const& handle,
333333
raft::device_matrix_view<const double, int64_t> X,
334334
std::optional<raft::device_vector_view<const double, int64_t>> sample_weight,
335335
raft::device_matrix_view<double, int64_t> centroids,
336-
raft::host_scalar_view<double, int64_t> inertia,
337-
raft::host_scalar_view<int64_t, int64_t> n_iter);
336+
raft::host_scalar_view<double> inertia,
337+
raft::host_scalar_view<int64_t> n_iter);
338338

339339
/**
340340
* @brief Find clusters with k-means algorithm.
@@ -383,8 +383,8 @@ void fit(raft::resources const& handle,
383383
raft::device_matrix_view<const int8_t, int> X,
384384
std::optional<raft::device_vector_view<const int8_t, int>> sample_weight,
385385
raft::device_matrix_view<int8_t, int> centroids,
386-
raft::host_scalar_view<int8_t, int> inertia,
387-
raft::host_scalar_view<int, int> n_iter);
386+
raft::host_scalar_view<int8_t> inertia,
387+
raft::host_scalar_view<int> n_iter);
388388

389389
/**
390390
* @brief Find balanced clusters with k-means algorithm.
@@ -581,6 +581,15 @@ void predict(raft::resources const& handle,
581581
bool normalize_weight,
582582
raft::host_scalar_view<float> inertia);
583583

584+
void predict(raft::resources const& handle,
585+
const kmeans::params& params,
586+
raft::device_matrix_view<const float, int64_t> X,
587+
std::optional<raft::device_vector_view<const float, int64_t>> sample_weight,
588+
raft::device_matrix_view<const float, int64_t> centroids,
589+
raft::device_vector_view<int64_t, int64_t> labels,
590+
bool normalize_weight,
591+
raft::host_scalar_view<float> inertia);
592+
584593
/**
585594
* @brief Predict the closest cluster each sample in X belongs to.
586595
*
@@ -632,10 +641,10 @@ void predict(raft::resources const& handle,
632641
*/
633642
void predict(raft::resources const& handle,
634643
const kmeans::params& params,
635-
raft::device_matrix_view<const float, int> X,
636-
std::optional<raft::device_vector_view<const float, int>> sample_weight,
637-
raft::device_matrix_view<const float, int> centroids,
638-
raft::device_vector_view<int64_t, int> labels,
644+
raft::device_matrix_view<const float, int64_t> X,
645+
std::optional<raft::device_vector_view<const float, int64_t>> sample_weight,
646+
raft::device_matrix_view<const float, int64_t> centroids,
647+
raft::device_vector_view<int64_t, int64_t> labels,
639648
bool normalize_weight,
640649
raft::host_scalar_view<float> inertia);
641650

@@ -748,10 +757,10 @@ void predict(raft::resources const& handle,
748757
*/
749758
void predict(raft::resources const& handle,
750759
const kmeans::params& params,
751-
raft::device_matrix_view<const double, int> X,
752-
std::optional<raft::device_vector_view<const double, int>> sample_weight,
753-
raft::device_matrix_view<const double, int> centroids,
754-
raft::device_vector_view<int64_t, int> labels,
760+
raft::device_matrix_view<const double, int64_t> X,
761+
std::optional<raft::device_vector_view<const double, int64_t>> sample_weight,
762+
raft::device_matrix_view<const double, int64_t> centroids,
763+
raft::device_vector_view<int64_t, int64_t> labels,
755764
bool normalize_weight,
756765
raft::host_scalar_view<double> inertia);
757766

cpp/include/cuvs/neighbors/common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t
248248
}
249249
// Something is wrong: have to make a copy and produce an owning dataset
250250
auto out_layout =
251-
raft::make_strided_layout(src.extents(), std::array<index_type, 2>{required_stride, 1});
251+
raft::make_strided_layout(src.extents(), cuda::std::array<index_type, 2>{required_stride, 1});
252252
auto out_array =
253253
raft::make_device_matrix<value_type, index_type>(res, src.extent(0), required_stride);
254254

@@ -310,7 +310,7 @@ auto make_strided_dataset(
310310
const bool stride_matches = required_stride == src_stride;
311311

312312
auto out_layout =
313-
raft::make_strided_layout(src.extents(), std::array<index_type, 2>{required_stride, 1});
313+
raft::make_strided_layout(src.extents(), cuda::std::array<index_type, 2>{required_stride, 1});
314314

315315
using out_mdarray_type = raft::device_matrix<value_type, index_type>;
316316
using out_layout_type = typename out_mdarray_type::layout_type;

cpp/include/cuvs/neighbors/ivf_pq.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,10 +269,8 @@ constexpr typename list_spec<SizeT, IdxT>::list_extents list_spec<SizeT, IdxT>::
269269
{
270270
// how many elems of pq_dim fit into one kIndexGroupVecLen-byte chunk
271271
auto pq_chunk = (kIndexGroupVecLen * 8u) / pq_bits;
272-
return raft::make_extents<SizeT>(raft::div_rounding_up_safe<SizeT>(n_rows, kIndexGroupSize),
273-
raft::div_rounding_up_safe<SizeT>(pq_dim, pq_chunk),
274-
kIndexGroupSize,
275-
kIndexGroupVecLen);
272+
return list_extents{raft::div_rounding_up_safe<SizeT>(n_rows, kIndexGroupSize),
273+
raft::div_rounding_up_safe<SizeT>(pq_dim, pq_chunk)};
276274
}
277275

278276
template <typename IdxT, typename SizeT = uint32_t>
@@ -335,8 +333,8 @@ struct index : cuvs::neighbors::index {
335333
static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
336334
"IdxT must be able to represent all values of uint32_t");
337335

338-
using pq_centers_extents = std::experimental::
339-
extents<uint32_t, raft::dynamic_extent, raft::dynamic_extent, raft::dynamic_extent>;
336+
using pq_centers_extents =
337+
raft::extents<uint32_t, raft::dynamic_extent, raft::dynamic_extent, raft::dynamic_extent>;
340338

341339
public:
342340
index(const index&) = delete;
@@ -2875,7 +2873,7 @@ void make_rotation_matrix(raft::resources const& res,
28752873
*/
28762874
void set_centers(raft::resources const& res,
28772875
index<int64_t>* index,
2878-
raft::device_matrix_view<const float, uint32_t> cluster_centers);
2876+
raft::device_matrix_view<const float, int64_t> cluster_centers);
28792877

28802878
/**
28812879
* @brief Public helper API for fetching a trained index's IVF centroids
@@ -2896,7 +2894,7 @@ void set_centers(raft::resources const& res,
28962894
*/
28972895
void extract_centers(raft::resources const& res,
28982896
const index<int64_t>& index,
2899-
raft::device_matrix_view<float, uint32_t, raft::row_major> cluster_centers);
2897+
raft::device_matrix_view<float, int64_t, raft::row_major> cluster_centers);
29002898

29012899
/** @copydoc extract_centers */
29022900
void extract_centers(raft::resources const& res,

cpp/include/cuvs/neighbors/knn_merge_parts.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,17 @@ void knn_merge_parts(raft::resources const& res,
2727
raft::device_matrix_view<const int64_t, int64_t> inV,
2828
raft::device_matrix_view<float, int64_t> outK,
2929
raft::device_matrix_view<int64_t, int64_t> outV,
30-
raft::device_vector_view<int64_t> translations);
30+
raft::device_vector_view<int64_t, int64_t> translations);
3131
void knn_merge_parts(raft::resources const& res,
3232
raft::device_matrix_view<const float, int64_t> inK,
3333
raft::device_matrix_view<const uint32_t, int64_t> inV,
3434
raft::device_matrix_view<float, int64_t> outK,
3535
raft::device_matrix_view<uint32_t, int64_t> outV,
36-
raft::device_vector_view<uint32_t> translations);
36+
raft::device_vector_view<uint32_t, int64_t> translations);
3737
void knn_merge_parts(raft::resources const& res,
3838
raft::device_matrix_view<const float, int64_t> inK,
3939
raft::device_matrix_view<const int32_t, int64_t> inV,
4040
raft::device_matrix_view<float, int64_t> outK,
4141
raft::device_matrix_view<int32_t, int64_t> outV,
42-
raft::device_vector_view<int32_t> translations);
42+
raft::device_vector_view<int32_t, int64_t> translations);
4343
} // namespace cuvs::neighbors

cpp/src/cluster/detail/kmeans.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,9 +1117,9 @@ void kmeans_predict(raft::resources const& handle,
11171117
template <typename DataT, typename IndexT = int>
11181118
void kmeans_transform(raft::resources const& handle,
11191119
const cuvs::cluster::kmeans::params& pams,
1120-
raft::device_matrix_view<const DataT> X,
1121-
raft::device_matrix_view<const DataT> centroids,
1122-
raft::device_matrix_view<DataT> X_new)
1120+
raft::device_matrix_view<const DataT, IndexT> X,
1121+
raft::device_matrix_view<const DataT, IndexT> centroids,
1122+
raft::device_matrix_view<DataT, IndexT> X_new)
11231123
{
11241124
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope("kmeans_transform");
11251125
raft::default_logger().set_level(pams.verbosity);

cpp/src/cluster/detail/kmeans_auto_find_k.cuh

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ void compute_dispersion(raft::resources const& handle,
2525
raft::device_matrix_view<const value_t, idx_t> X,
2626
cuvs::cluster::kmeans::params& params,
2727
raft::device_matrix_view<value_t, idx_t> centroids_view,
28-
raft::device_vector_view<idx_t> labels,
29-
raft::device_vector_view<idx_t> clusterSizes,
28+
raft::device_vector_view<idx_t, idx_t> labels,
29+
raft::device_vector_view<idx_t, idx_t> clusterSizes,
3030
rmm::device_uvector<char>& workspace,
3131
raft::host_vector_view<value_t> clusterDispertionView,
3232
raft::host_vector_view<value_t> resultsView,
@@ -109,12 +109,15 @@ void find_k(raft::resources const& handle,
109109

110110
auto centroids_view =
111111
raft::make_device_matrix_view<value_t, idx_t>(centroids.data_handle(), left, d);
112+
auto labels_view = raft::make_device_vector_view<idx_t, idx_t>(labels.data_handle(), n);
113+
auto clusterSizes_view =
114+
raft::make_device_vector_view<idx_t, idx_t>(clusterSizes.data_handle(), kmax);
112115
compute_dispersion<value_t, idx_t>(handle,
113116
X,
114117
params,
115118
centroids_view,
116-
labels.view(),
117-
clusterSizes.view(),
119+
labels_view,
120+
clusterSizes_view,
118121
workspace,
119122
clusterDispertionView,
120123
resultsView,
@@ -133,8 +136,8 @@ void find_k(raft::resources const& handle,
133136
X,
134137
params,
135138
centroids_view,
136-
labels.view(),
137-
clusterSizes.view(),
139+
labels_view,
140+
clusterSizes_view,
138141
workspace,
139142
clusterDispertionView,
140143
resultsView,
@@ -159,8 +162,8 @@ void find_k(raft::resources const& handle,
159162
X,
160163
params,
161164
centroids_view,
162-
labels.view(),
163-
clusterSizes.view(),
165+
labels_view,
166+
clusterSizes_view,
164167
workspace,
165168
clusterDispertionView,
166169
resultsView,

cpp/src/cluster/detail/single_linkage.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ namespace cuvs::cluster::agglomerative::detail {
4343
template <typename value_t = float,
4444
typename value_idx = int,
4545
typename nnz_t = size_t,
46-
typename Accessor = raft::device_accessor<std::experimental::default_accessor<value_t>>>
46+
typename Accessor = raft::device_accessor<cuda::std::default_accessor<value_t>>>
4747
void build_mr_linkage(
4848
raft::resources const& handle,
4949
raft::mdspan<const value_t, raft::matrix_extent<value_idx>, raft::row_major, Accessor> X,

cpp/src/cluster/kmeans_fit_double.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ void fit(raft::resources const& handle,
4444
raft::device_matrix_view<const double, int> X,
4545
std::optional<raft::device_vector_view<const double, int>> sample_weight,
4646
raft::device_matrix_view<double, int> centroids,
47-
raft::host_scalar_view<double, int> inertia,
48-
raft::host_scalar_view<int, int> n_iter)
47+
raft::host_scalar_view<double> inertia,
48+
raft::host_scalar_view<int> n_iter)
4949
{
5050
cuvs::cluster::kmeans::fit<double, int>(
5151
handle, params, X, sample_weight, centroids, inertia, n_iter);
@@ -56,8 +56,8 @@ void fit(raft::resources const& handle,
5656
raft::device_matrix_view<const double, int64_t> X,
5757
std::optional<raft::device_vector_view<const double, int64_t>> sample_weight,
5858
raft::device_matrix_view<double, int64_t> centroids,
59-
raft::host_scalar_view<double, int64_t> inertia,
60-
raft::host_scalar_view<int64_t, int64_t> n_iter)
59+
raft::host_scalar_view<double> inertia,
60+
raft::host_scalar_view<int64_t> n_iter)
6161
{
6262
cuvs::cluster::kmeans::fit<double, int64_t>(
6363
handle, params, X, sample_weight, centroids, inertia, n_iter);

0 commit comments

Comments
 (0)