@@ -236,25 +236,17 @@ void KnnQuery::topk_impl(const Matrix &items, const Matrix &query, int k,
236236 }
237237
238238 auto current_k = std::min (k, static_cast <int >(temp_distances.cols ));
239- rmm::device_uvector<float > best_distances (temp_distances.rows * current_k,
240- stream, mr.get ());
241- rmm::device_uvector<int > best_indices (temp_distances.rows * current_k,
242- stream, mr.get ());
243239
244240 auto distance_view = raft::make_device_matrix_view<const float , int64_t >(
245241 temp_distances, temp_distances.rows , temp_distances.cols );
246242
247243 raft::matrix::select_k<float , int >(
248244 handle, distance_view, std::nullopt ,
249245 raft::make_device_matrix_view<float , int64_t >(
250- best_distances. data () , temp_distances.rows , current_k),
246+ distances + start * k , temp_distances.rows , current_k),
251247 raft::make_device_matrix_view<int , int64_t >(
252- best_indices.data (), temp_distances.rows , current_k),
253- false );
254-
255- // raft::select_k doesn't sort inputs - so we have to do it here
256- argsort (best_indices.data (), best_distances.data (), temp_distances.rows ,
257- current_k, indices + start * k, distances + start * k);
248+ indices + start * k, temp_distances.rows , current_k),
249+ false , true );
258250 // TODO: callback per batch (show progress etc)
259251 }
260252
@@ -271,44 +263,6 @@ void KnnQuery::topk_impl(const Matrix &items, const Matrix &query, int k,
271263 }
272264}
273265
274- void KnnQuery::argsort (const int *input_indices, const float *input_distances,
275- int rows, int cols, int *indices, float *distances) {
276- rmm::cuda_stream_view stream;
277- auto segment_offsets = thrust::make_transform_iterator (
278- thrust::make_counting_iterator<int >(0 ),
279- [=] __device__ (int i) { return i * cols; });
280-
281- void *temp_mem = NULL ;
282- size_t temp_size = 0 ;
283-
284- // sort the values.
285- if (rows > 1 ) {
286- auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending (
287- NULL , temp_size, input_distances, distances, input_indices, indices,
288- rows * cols, rows, segment_offsets, segment_offsets + 1 , 0 ,
289- sizeof (float ) * 8 , stream);
290- CHECK_CUDA (err);
291- temp_mem = mr->allocate (temp_size, stream);
292- err = cub::DeviceSegmentedRadixSort::SortPairsDescending (
293- temp_mem, temp_size, input_distances, distances, input_indices, indices,
294- rows * cols, rows, segment_offsets, segment_offsets + 1 , 0 ,
295- sizeof (float ) * 8 , stream);
296- CHECK_CUDA (err);
297- } else {
298- size_t temp_size = 0 ;
299- auto err = cub::DeviceRadixSort::SortPairsDescending (
300- NULL , temp_size, input_distances, distances, input_indices, indices,
301- cols, 0 , sizeof (float ) * 8 , stream);
302- CHECK_CUDA (err);
303- temp_mem = mr->allocate (temp_size, stream);
304- err = cub::DeviceRadixSort::SortPairsDescending (
305- temp_mem, temp_size, input_distances, distances, input_indices, indices,
306- cols, 0 , sizeof (float ) * 8 , stream);
307- CHECK_CUDA (err);
308- }
309- mr->deallocate (temp_mem, temp_size, stream);
310- }
311-
312266KnnQuery::~KnnQuery () {}
313267
314268} // namespace gpu
0 commit comments