Skip to content

Commit 4283257

Browse files
authored
use raft 23.08 (#681)
Update to use the latest raft version. RAFT 23.08 includes a sort option as part of the select_k method rapidsai/raft#1615 which means we don't have to sort the output ourselves.
1 parent 0db47c7 commit 4283257

File tree

3 files changed

+6
-55
lines changed

3 files changed

+6
-55
lines changed

implicit/gpu/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ else()
1414
add_cython_target(_cuda CXX)
1515

1616
# use rapids-cmake to install dependencies
17-
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-23.06/RAPIDS.cmake
17+
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-23.08/RAPIDS.cmake
1818
${CMAKE_BINARY_DIR}/RAPIDS.cmake)
1919
include(${CMAKE_BINARY_DIR}/RAPIDS.cmake)
2020
include(rapids-cmake)
@@ -57,10 +57,10 @@ else()
5757
# get raft
5858
# note: we're using RAFT in header only mode right now - mainly to reduce binary
5959
# size of the compiled wheels
60-
rapids_cpm_find(raft 23.06
60+
rapids_cpm_find(raft 23.08
6161
CPM_ARGS
6262
GIT_REPOSITORY https://github.com/rapidsai/raft.git
63-
GIT_TAG branch-23.06
63+
GIT_TAG branch-23.08
6464
DOWNLOAD_ONLY YES
6565
)
6666
include_directories(${raft_SOURCE_DIR}/cpp/include)

implicit/gpu/knn.cu

Lines changed: 3 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -236,25 +236,17 @@ void KnnQuery::topk_impl(const Matrix &items, const Matrix &query, int k,
236236
}
237237

238238
auto current_k = std::min(k, static_cast<int>(temp_distances.cols));
239-
rmm::device_uvector<float> best_distances(temp_distances.rows * current_k,
240-
stream, mr.get());
241-
rmm::device_uvector<int> best_indices(temp_distances.rows * current_k,
242-
stream, mr.get());
243239

244240
auto distance_view = raft::make_device_matrix_view<const float, int64_t>(
245241
temp_distances, temp_distances.rows, temp_distances.cols);
246242

247243
raft::matrix::select_k<float, int>(
248244
handle, distance_view, std::nullopt,
249245
raft::make_device_matrix_view<float, int64_t>(
250-
best_distances.data(), temp_distances.rows, current_k),
246+
distances + start * k, temp_distances.rows, current_k),
251247
raft::make_device_matrix_view<int, int64_t>(
252-
best_indices.data(), temp_distances.rows, current_k),
253-
false);
254-
255-
// raft::select_k doesn't sort inputs - so we have to do it here
256-
argsort(best_indices.data(), best_distances.data(), temp_distances.rows,
257-
current_k, indices + start * k, distances + start * k);
248+
indices + start * k, temp_distances.rows, current_k),
249+
false, true);
258250
// TODO: callback per batch (show progress etc)
259251
}
260252

@@ -271,44 +263,6 @@ void KnnQuery::topk_impl(const Matrix &items, const Matrix &query, int k,
271263
}
272264
}
273265

274-
void KnnQuery::argsort(const int *input_indices, const float *input_distances,
275-
int rows, int cols, int *indices, float *distances) {
276-
rmm::cuda_stream_view stream;
277-
auto segment_offsets = thrust::make_transform_iterator(
278-
thrust::make_counting_iterator<int>(0),
279-
[=] __device__(int i) { return i * cols; });
280-
281-
void *temp_mem = NULL;
282-
size_t temp_size = 0;
283-
284-
// sort the values.
285-
if (rows > 1) {
286-
auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
287-
NULL, temp_size, input_distances, distances, input_indices, indices,
288-
rows * cols, rows, segment_offsets, segment_offsets + 1, 0,
289-
sizeof(float) * 8, stream);
290-
CHECK_CUDA(err);
291-
temp_mem = mr->allocate(temp_size, stream);
292-
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
293-
temp_mem, temp_size, input_distances, distances, input_indices, indices,
294-
rows * cols, rows, segment_offsets, segment_offsets + 1, 0,
295-
sizeof(float) * 8, stream);
296-
CHECK_CUDA(err);
297-
} else {
298-
size_t temp_size = 0;
299-
auto err = cub::DeviceRadixSort::SortPairsDescending(
300-
NULL, temp_size, input_distances, distances, input_indices, indices,
301-
cols, 0, sizeof(float) * 8, stream);
302-
CHECK_CUDA(err);
303-
temp_mem = mr->allocate(temp_size, stream);
304-
err = cub::DeviceRadixSort::SortPairsDescending(
305-
temp_mem, temp_size, input_distances, distances, input_indices, indices,
306-
cols, 0, sizeof(float) * 8, stream);
307-
CHECK_CUDA(err);
308-
}
309-
mr->deallocate(temp_mem, temp_size, stream);
310-
}
311-
312266
KnnQuery::~KnnQuery() {}
313267

314268
} // namespace gpu

implicit/gpu/knn.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ class KnnQuery {
2525
const COOMatrix *query_filter = NULL,
2626
Vector<int> *item_filter = NULL);
2727

28-
void argsort(const int *input_indices, const float *input_distances, int rows,
29-
int cols, int *indices, float *distances);
30-
3128
protected:
3229
std::unique_ptr<rmm::mr::device_memory_resource> mr;
3330
raft::resources handle;

0 commit comments

Comments
 (0)