@@ -181,8 +181,8 @@ void fit(raft::resources const& handle,
181181 raft::device_matrix_view<const float , int > X,
182182 std::optional<raft::device_vector_view<const float , int >> sample_weight,
183183 raft::device_matrix_view<float , int > centroids,
184- raft::host_scalar_view<float , int > inertia,
185- raft::host_scalar_view<int , int > n_iter);
184+ raft::host_scalar_view<float > inertia,
185+ raft::host_scalar_view<int > n_iter);
186186
187187/* *
188188 * @brief Find clusters with k-means algorithm.
@@ -232,8 +232,8 @@ void fit(raft::resources const& handle,
232232 raft::device_matrix_view<const float , int64_t > X,
233233 std::optional<raft::device_vector_view<const float , int64_t >> sample_weight,
234234 raft::device_matrix_view<float , int64_t > centroids,
235- raft::host_scalar_view<float , int64_t > inertia,
236- raft::host_scalar_view<int64_t , int64_t > n_iter);
235+ raft::host_scalar_view<float > inertia,
236+ raft::host_scalar_view<int64_t > n_iter);
237237
238238/* *
239239 * @brief Find clusters with k-means algorithm.
@@ -282,8 +282,8 @@ void fit(raft::resources const& handle,
282282 raft::device_matrix_view<const double , int > X,
283283 std::optional<raft::device_vector_view<const double , int >> sample_weight,
284284 raft::device_matrix_view<double , int > centroids,
285- raft::host_scalar_view<double , int > inertia,
286- raft::host_scalar_view<int , int > n_iter);
285+ raft::host_scalar_view<double > inertia,
286+ raft::host_scalar_view<int > n_iter);
287287
288288/* *
289289 * @brief Find clusters with k-means algorithm.
@@ -333,8 +333,8 @@ void fit(raft::resources const& handle,
333333 raft::device_matrix_view<const double , int64_t > X,
334334 std::optional<raft::device_vector_view<const double , int64_t >> sample_weight,
335335 raft::device_matrix_view<double , int64_t > centroids,
336- raft::host_scalar_view<double , int64_t > inertia,
337- raft::host_scalar_view<int64_t , int64_t > n_iter);
336+ raft::host_scalar_view<double > inertia,
337+ raft::host_scalar_view<int64_t > n_iter);
338338
339339/* *
340340 * @brief Find clusters with k-means algorithm.
@@ -383,8 +383,8 @@ void fit(raft::resources const& handle,
383383 raft::device_matrix_view<const int8_t , int > X,
384384 std::optional<raft::device_vector_view<const int8_t , int >> sample_weight,
385385 raft::device_matrix_view<int8_t , int > centroids,
386- raft::host_scalar_view<int8_t , int > inertia,
387- raft::host_scalar_view<int , int > n_iter);
386+ raft::host_scalar_view<int8_t > inertia,
387+ raft::host_scalar_view<int > n_iter);
388388
389389/* *
390390 * @brief Find balanced clusters with k-means algorithm.
@@ -581,6 +581,15 @@ void predict(raft::resources const& handle,
581581 bool normalize_weight,
582582 raft::host_scalar_view<float > inertia);
583583
584+ void predict (raft::resources const & handle,
585+ const kmeans::params& params,
586+ raft::device_matrix_view<const float , int64_t > X,
587+ std::optional<raft::device_vector_view<const float , int64_t >> sample_weight,
588+ raft::device_matrix_view<const float , int64_t > centroids,
589+ raft::device_vector_view<int64_t , int64_t > labels,
590+ bool normalize_weight,
591+ raft::host_scalar_view<float > inertia);
592+
584593/* *
585594 * @brief Predict the closest cluster each sample in X belongs to.
586595 *
@@ -632,10 +641,10 @@ void predict(raft::resources const& handle,
632641 */
633642void predict (raft::resources const & handle,
634643 const kmeans::params& params,
635- raft::device_matrix_view<const float , int > X,
636- std::optional<raft::device_vector_view<const float , int >> sample_weight,
637- raft::device_matrix_view<const float , int > centroids,
638- raft::device_vector_view<int64_t , int > labels,
644+ raft::device_matrix_view<const float , int64_t > X,
645+ std::optional<raft::device_vector_view<const float , int64_t >> sample_weight,
646+ raft::device_matrix_view<const float , int64_t > centroids,
647+ raft::device_vector_view<int64_t , int64_t > labels,
639648 bool normalize_weight,
640649 raft::host_scalar_view<float > inertia);
641650
@@ -748,10 +757,10 @@ void predict(raft::resources const& handle,
748757 */
749758void predict (raft::resources const & handle,
750759 const kmeans::params& params,
751- raft::device_matrix_view<const double , int > X,
752- std::optional<raft::device_vector_view<const double , int >> sample_weight,
753- raft::device_matrix_view<const double , int > centroids,
754- raft::device_vector_view<int64_t , int > labels,
760+ raft::device_matrix_view<const double , int64_t > X,
761+ std::optional<raft::device_vector_view<const double , int64_t >> sample_weight,
762+ raft::device_matrix_view<const double , int64_t > centroids,
763+ raft::device_vector_view<int64_t , int64_t > labels,
755764 bool normalize_weight,
756765 raft::host_scalar_view<double > inertia);
757766
0 commit comments