Skip to content

Commit cc67077

Browse files
committed
Add knn dispatch, CPU and CUDA kernels
CPU kernel uses nanoflann KD-tree for efficient nearest neighbor search. CUDA kernel uses brute-force pairwise distance with insertion sort top-k, supporting cosine distance. Python wrapper and tests included.
1 parent 3337ec4 commit cc67077

File tree

6 files changed

+596
-0
lines changed

6 files changed

+596
-0
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#include "../knn.h"
2+
3+
#include <ATen/ATen.h>
4+
#include <torch/library.h>
5+
6+
#include "utils/KDTreeVectorOfVectorsAdaptor.h"
7+
#include "utils/nanoflann.hpp"
8+
9+
namespace pyg {
10+
namespace ops {
11+
12+
namespace {
13+
14+
at::Tensor knn_kernel(const at::Tensor& x,
                      const at::Tensor& y,
                      const std::optional<at::Tensor>& ptr_x,
                      const std::optional<at::Tensor>& ptr_y,
                      int64_t k,
                      bool cosine,
                      int64_t num_workers) {
  // For every row of `y`, finds its (up to) `k` nearest neighbors in `x`
  // using a nanoflann KD-tree over squared Euclidean distance. If `ptr_x`
  // and `ptr_y` (CSR-style batch offsets into `x`/`y`) are given, one tree
  // is built per example and queries are restricted to the matching example.
  // Returns a `[2, num_matches]` index tensor, row 0 holding `y` indices and
  // row 1 holding `x` indices. `num_workers` is currently unused on CPU.
  TORCH_CHECK(!cosine, "`cosine` argument not supported on CPU");
  // The batch pointers only make sense as a pair; catching a mismatch here
  // avoids an unhelpful `std::optional` access error (or silently searching
  // across example boundaries) further down.
  TORCH_CHECK(ptr_x.has_value() == ptr_y.has_value(),
              "`ptr_x` and `ptr_y` must either both be set or both be unset");
  if (ptr_x.has_value()) {
    TORCH_CHECK(ptr_x.value().numel() == ptr_y.value().numel(),
                "ptr_x and ptr_y must have the same number of elements");
  }

  // Flat (x_index, y_index) pairs; viewed as a `[size, 2]` tensor at the end.
  std::vector<size_t> out_vec;

  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(),
      "knn_cpu", [&] {
        auto x_data = x.data_ptr<scalar_t>();
        auto y_data = y.data_ptr<scalar_t>();
        typedef std::vector<std::vector<scalar_t>> vec_t;

        if (!ptr_x.has_value()) {  // Single example over all points:
          // Copy `x` into the vector-of-vectors layout nanoflann adapts over.
          vec_t pts(x.size(0));
          for (int64_t i = 0; i < x.size(0); i++) {
            pts[i].resize(x.size(1));
            for (int64_t j = 0; j < x.size(1); j++) {
              pts[i][j] = x_data[i * x.size(1) + j];
            }
          }

          typedef KDTreeVectorOfVectorsAdaptor<vec_t, scalar_t> my_kd_tree_t;
          my_kd_tree_t mat_index(x.size(1), pts, 10);  // 10 = max leaf size

          std::vector<size_t> ret_index(k);
          std::vector<scalar_t> out_dist_sqr(k);
          for (int64_t i = 0; i < y.size(0); i++) {
            // `num_matches` may be < k when `x` holds fewer than `k` points.
            size_t num_matches = mat_index.index->knnSearch(
                y_data + i * y.size(1), k, &ret_index[0], &out_dist_sqr[0]);
            for (size_t j = 0; j < num_matches; j++) {
              out_vec.push_back(ret_index[j]);
              out_vec.push_back(i);
            }
          }
        } else {  // Batched: one KD-tree per example.
          auto ptr_x_data = ptr_x.value().data_ptr<int64_t>();
          auto ptr_y_data = ptr_y.value().data_ptr<int64_t>();

          for (int64_t b = 0; b < ptr_x.value().size(0) - 1; b++) {
            auto x_start = ptr_x_data[b], x_end = ptr_x_data[b + 1];
            auto y_start = ptr_y_data[b], y_end = ptr_y_data[b + 1];

            if (x_start == x_end || y_start == y_end)
              continue;  // Nothing to match in this example.

            vec_t pts(x_end - x_start);
            for (int64_t i = 0; i < x_end - x_start; i++) {
              pts[i].resize(x.size(1));
              for (int64_t j = 0; j < x.size(1); j++) {
                pts[i][j] = x_data[(i + x_start) * x.size(1) + j];
              }
            }

            typedef KDTreeVectorOfVectorsAdaptor<vec_t, scalar_t> my_kd_tree_t;
            my_kd_tree_t mat_index(x.size(1), pts, 10);

            std::vector<size_t> ret_index(k);
            std::vector<scalar_t> out_dist_sqr(k);
            for (int64_t i = y_start; i < y_end; i++) {
              size_t num_matches = mat_index.index->knnSearch(
                  y_data + i * y.size(1), k, &ret_index[0], &out_dist_sqr[0]);
              for (size_t j = 0; j < num_matches; j++) {
                // Tree indices are local to the example; shift back to
                // global `x` indices.
                out_vec.push_back(x_start + ret_index[j]);
                out_vec.push_back(i);
              }
            }
          }
        }
      });

  // View the flat pair buffer as `[size, 2]`, transpose, and swap the two
  // rows so the output is `[2, size]` in (y, x) order. `.clone()` detaches
  // the result from the function-local `out_vec` storage before it dies.
  const int64_t size = out_vec.size() / 2;
  auto out =
      at::from_blob(out_vec.data(), {size, 2}, x.options().dtype(at::kLong));
  return out.t().index_select(0, at::tensor({1, 0})).clone();
}
95+
96+
} // namespace
97+
98+
// Register the CPU implementation of the `pyg::knn` operator with the
// PyTorch dispatcher (schema is declared elsewhere via TORCH_LIBRARY_FRAGMENT).
TORCH_LIBRARY_IMPL(pyg, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::knn"), TORCH_FN(knn_kernel));
}
101+
102+
} // namespace ops
103+
} // namespace pyg
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
#include "../knn.h"
2+
#include "utils.cuh"
3+
4+
#include <ATen/ATen.h>
5+
#include <ATen/cuda/CUDAContext.h>
6+
#include <torch/library.h>
7+
8+
namespace pyg {
9+
namespace ops {
10+
11+
namespace {
12+
13+
#define KNN_THREADS 256
14+
15+
template <typename scalar_t>
struct Cosine {
  // Inner product of row `n_a` of `a` with row `n_b` of `b`, where each
  // row holds `size` consecutive features (row-major, contiguous).
  static inline __device__ scalar_t dot(const scalar_t* a,
                                        const scalar_t* b,
                                        int64_t n_a,
                                        int64_t n_b,
                                        int64_t size) {
    const scalar_t* row_a = a + n_a * size;
    const scalar_t* row_b = b + n_b * size;
    scalar_t sum = 0;
    for (int64_t d = 0; d < size; d++) {
      sum += row_a[d] * row_b[d];
    }
    return sum;
  }

  // L2 norm of row `n_a` of `a` (sqrt of the row's self inner product).
  static inline __device__ scalar_t norm(const scalar_t* a,
                                         int64_t n_a,
                                         int64_t size) {
    const scalar_t* row = a + n_a * size;
    scalar_t sum = 0;
    for (int64_t d = 0; d < size; d++) {
      sum += row[d] * row[d];
    }
    return sqrt(sum);
  }
};
39+
40+
// Brute-force k-NN: one thread per query point of `y`. Each thread scans
// every candidate point of its own example in `x` and maintains a running
// top-k (sorted ascending by distance) via insertion sort in local arrays.
// Preconditions (enforced by the host wrapper): `k <= 100`, `x`/`y`
// contiguous row-major [*, dim], `ptr_x`/`ptr_y` hold `num_examples + 1`
// CSR-style offsets.
template <typename scalar_t>
__global__ void knn_cuda_kernel(const scalar_t* __restrict__ x,
                                const scalar_t* __restrict__ y,
                                const int64_t* __restrict__ ptr_x,
                                const int64_t* __restrict__ ptr_y,
                                int64_t* __restrict__ row,
                                int64_t* __restrict__ col,
                                const int64_t k,
                                const int64_t n,
                                const int64_t m,
                                const int64_t dim,
                                const int64_t num_examples,
                                const bool cosine) {
  // Flat global thread id = query index into `y`; guard the grid tail.
  const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
  if (n_y >= m)
    return;

  // Which batch example this query belongs to (looked up via `ptr_y`).
  const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);

  // Running top-k, kept sorted ascending by distance. Fixed capacity of
  // 100 entries; the host wrapper rejects `k > 100`.
  scalar_t best_dist[100];
  int64_t best_idx[100];

  for (int e = 0; e < k; e++) {
    best_dist[e] = (scalar_t)1e10;  // sentinel distance: "slot unfilled"
    best_idx[e] = -1;               // -1 marks unfilled slots for the host
  }

  // Scan all candidate points belonging to the same example.
  for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
    scalar_t tmp_dist = 0;

    if (cosine) {
      // Cosine distance: 1 - cosine similarity of the two rows.
      tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
                 (Cosine<scalar_t>::norm(x, n_x, dim) *
                  Cosine<scalar_t>::norm(y, n_y, dim));
      tmp_dist = (scalar_t)1. - tmp_dist;
    } else {
      // Squared Euclidean distance (no sqrt needed for ranking).
      for (int64_t d = 0; d < dim; d++) {
        scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
        tmp_dist += diff * diff;
      }
    }

    // Insertion sort step: find the first slot this candidate beats,
    // shift the tail one position right, and insert.
    for (int64_t e1 = 0; e1 < k; e1++) {
      if (scalar_gt(best_dist[e1], tmp_dist)) {
        for (int64_t e2 = k - 1; e2 > e1; e2--) {
          best_dist[e2] = best_dist[e2 - 1];
          best_idx[e2] = best_idx[e2 - 1];
        }
        best_dist[e1] = tmp_dist;
        best_idx[e1] = n_x;
        break;
      }
    }
  }

  // Emit the k slots for this query; unfilled slots keep `best_idx == -1`
  // and are masked out by the host wrapper.
  for (int64_t e = 0; e < k; e++) {
    row[n_y * k + e] = n_y;
    col[n_y * k + e] = best_idx[e];
  }
}
100+
101+
at::Tensor knn_cuda(const at::Tensor& x,
                    const at::Tensor& y,
                    const std::optional<at::Tensor>& ptr_x,
                    const std::optional<at::Tensor>& ptr_y,
                    int64_t k,
                    bool cosine,
                    int64_t num_workers) {
  // CUDA host wrapper: validates inputs, materializes default batch
  // pointers, launches the brute-force kernel (one thread per `y` row), and
  // compacts the results into a `[2, num_matches]` index tensor (row 0: `y`
  // indices, row 1: `x` indices). `num_workers` is ignored on CUDA.
  TORCH_CHECK(x.is_cuda() && y.is_cuda(), "Inputs must be CUDA tensors");
  TORCH_CHECK(x.is_contiguous() && y.is_contiguous(),
              "Inputs must be contiguous");
  // The kernel keeps its running top-k in fixed-size local arrays of 100.
  TORCH_CHECK(k <= 100, "`k` must be <= 100");

  // A zero-block grid is an invalid launch configuration, and the default
  // `at::arange` below would use a step of 0 — short-circuit empty inputs.
  if (x.size(0) == 0 || y.size(0) == 0)
    return at::empty({2, 0}, x.options().dtype(at::kLong));

  std::optional<at::Tensor> ptr_x_v = ptr_x;
  std::optional<at::Tensor> ptr_y_v = ptr_y;

  // Without explicit batch pointers, treat all points as a single example:
  // offsets [0, num_points].
  if (!ptr_x_v.has_value())
    ptr_x_v =
        at::arange(0, x.size(0) + 1, x.size(0), x.options().dtype(at::kLong));
  if (!ptr_y_v.has_value())
    ptr_y_v =
        at::arange(0, y.size(0) + 1, y.size(0), y.options().dtype(at::kLong));

  // The kernel reads raw `data_ptr`s, so user-supplied pointer tensors must
  // live on the GPU and be contiguous.
  TORCH_CHECK(ptr_x_v.value().is_cuda() && ptr_y_v.value().is_cuda(),
              "`ptr_x` and `ptr_y` must be CUDA tensors");
  TORCH_CHECK(
      ptr_x_v.value().is_contiguous() && ptr_y_v.value().is_contiguous(),
      "`ptr_x` and `ptr_y` must be contiguous");
  TORCH_CHECK(ptr_x_v.value().numel() == ptr_y_v.value().numel(),
              "ptr_x and ptr_y must have the same number of elements");

  // One output slot per (query, neighbor); `col == -1` marks slots left
  // unfilled because the example held fewer than `k` candidates.
  auto row = at::empty({y.size(0) * k}, ptr_y_v.value().options());
  auto col = at::full({y.size(0) * k}, -1, ptr_y_v.value().options());

  // Ceil-div so every query row gets a thread; tail guarded in the kernel.
  dim3 BLOCKS((y.size(0) + KNN_THREADS - 1) / KNN_THREADS);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND(
      at::ScalarType::Half, x.scalar_type(), "knn_cuda", [&] {
        knn_cuda_kernel<scalar_t><<<BLOCKS, KNN_THREADS, 0, stream>>>(
            x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
            ptr_x_v.value().data_ptr<int64_t>(),
            ptr_y_v.value().data_ptr<int64_t>(), row.data_ptr<int64_t>(),
            col.data_ptr<int64_t>(), k, x.size(0), y.size(0), x.size(1),
            ptr_x_v.value().numel() - 1, cosine);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });

  // Drop unfilled slots and stack into the [2, num_matches] result.
  auto mask = col != -1;
  return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
145+
146+
} // namespace
147+
148+
// Register the CUDA implementation of the `pyg::knn` operator with the
// PyTorch dispatcher (schema is declared elsewhere via TORCH_LIBRARY_FRAGMENT).
TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::knn"), TORCH_FN(knn_cuda));
}
151+
152+
} // namespace ops
153+
} // namespace pyg

pyg_lib/csrc/ops/knn.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "knn.h"
2+
3+
#include <ATen/core/dispatch/Dispatcher.h>
4+
#include <torch/library.h>
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
PYG_API at::Tensor knn(const at::Tensor& x,
                       const at::Tensor& y,
                       const std::optional<at::Tensor>& ptr_x,
                       const std::optional<at::Tensor>& ptr_y,
                       int64_t k,
                       bool cosine,
                       int64_t num_workers) {
  // Dispatch entry point for `pyg::knn`: validates arguments, makes the
  // inputs contiguous (the device kernels index raw pointers row-major), and
  // forwards to the CPU/CUDA implementation registered with the dispatcher.
  // Returns a `[2, num_matches]` index tensor.
  at::TensorArg x_arg{x, "x", 0};
  at::TensorArg y_arg{y, "y", 1};
  at::CheckedFrom c{"knn"};

  at::checkAllDefined(c, {x_arg, y_arg});
  at::checkDim(c, x_arg, 2);
  at::checkDim(c, y_arg, 2);
  // The kernels dispatch on `x`'s scalar type only and then reinterpret
  // `y`'s storage with the same type, so a dtype mismatch must be rejected
  // here rather than failing deep inside a kernel.
  at::checkSameType(c, x_arg, y_arg);

  TORCH_CHECK(x.size(1) == y.size(1), "x and y must have the same feature dim");
  TORCH_CHECK(k > 0, "k must be positive");
  // Batch pointers only make sense as a pair; a lone `ptr_x` would otherwise
  // surface as an unhelpful optional-access error inside the kernels.
  TORCH_CHECK(ptr_x.has_value() == ptr_y.has_value(),
              "`ptr_x` and `ptr_y` must either both be set or both be unset");

  auto x_c = x.contiguous();
  auto y_c = y.contiguous();

  // Resolve the schema once; `op` is cached across calls.
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("pyg::knn", "")
                       .typed<decltype(knn)>();
  return op.call(x_c, y_c, ptr_x, ptr_y, k, cosine, num_workers);
}
35+
36+
// Declare the `pyg::knn` operator schema. Per-device implementations are
// registered separately via TORCH_LIBRARY_IMPL (CPU and CUDA).
TORCH_LIBRARY_FRAGMENT(pyg, m) {
  m.def(
      TORCH_SELECTIVE_SCHEMA("pyg::knn(Tensor x, Tensor y, Tensor? ptr_x=None, "
                             "Tensor? ptr_y=None, int k=1, bool cosine=False, "
                             "int num_workers=1) -> Tensor"));
}
42+
43+
} // namespace ops
44+
} // namespace pyg

pyg_lib/csrc/ops/knn.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "pyg_lib/csrc/macros.h"
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
// Finds for each row in `y` its `k` nearest neighbors in `x` and returns a
// `[2, num_matches]` index tensor (row 0: `y` indices, row 1: `x` indices).
// `ptr_x`/`ptr_y` optionally provide CSR-style batch offsets so neighbors
// are only searched within the same example; they must be given together.
// `cosine` selects cosine distance (supported on CUDA only; the CPU kernel
// rejects it). `num_workers` is currently unused by both kernels.
PYG_API at::Tensor knn(const at::Tensor& x,
                       const at::Tensor& y,
                       const std::optional<at::Tensor>& ptr_x,
                       const std::optional<at::Tensor>& ptr_y,
                       int64_t k,
                       bool cosine,
                       int64_t num_workers);
16+
17+
} // namespace ops
18+
} // namespace pyg

0 commit comments

Comments
 (0)