Skip to content

Commit d1c8562

Browse files
committed
Add radius dispatch, CPU and CUDA kernels
CPU kernel uses nanoflann KD-tree radiusSearch. CUDA kernel uses brute-force pairwise squared Euclidean distance. Supports max_num_neighbors cap and ignore_same_index. Tests included.
1 parent a755d57 commit d1c8562

File tree

5 files changed

+364
-0
lines changed

5 files changed

+364
-0
lines changed
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#include "../radius.h"
2+
3+
#include <ATen/ATen.h>
4+
#include <torch/library.h>
5+
6+
#include "utils/KDTreeVectorOfVectorsAdaptor.h"
7+
#include "utils/nanoflann.hpp"
8+
9+
namespace pyg {
10+
namespace ops {
11+
12+
namespace {
13+
14+
// CPU backend for `pyg::radius`: for every point in `y`, finds all points in
// `x` within Euclidean distance `r` using a nanoflann KD-tree, capped at
// `max_num_neighbors` matches per query. Returns a `[2, num_edges]` long
// tensor with y (query) indices in row 0 and x (reference) indices in row 1.
// `ptr_x`/`ptr_y` are optional CSR offsets partitioning `x`/`y` into
// independent batch examples. `num_workers` is accepted for schema
// compatibility but is not used by this backend.
at::Tensor radius_kernel(const at::Tensor& x,
                         const at::Tensor& y,
                         const std::optional<at::Tensor>& ptr_x,
                         const std::optional<at::Tensor>& ptr_y,
                         double r,
                         int64_t max_num_neighbors,
                         int64_t num_workers,
                         bool ignore_same_index) {
  // Flat list of (x_index, y_index) pairs. Use int64_t rather than size_t so
  // the raw buffer layout matches `at::kLong` on every platform (size_t is
  // 32-bit on some targets, which would corrupt the `from_blob` view below).
  std::vector<int64_t> out_vec;

  // Batched mode needs both offset tensors; fail loudly instead of
  // dereferencing an empty optional in the `else` branch below.
  TORCH_CHECK(ptr_x.has_value() == ptr_y.has_value(),
              "'ptr_x' and 'ptr_y' must either both be given or both be absent");
  if (ptr_x.has_value()) {
    TORCH_CHECK(ptr_x.value().numel() == ptr_y.value().numel(),
                "'ptr_x' and 'ptr_y' must have the same number of elements");
  }

  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(),
      "radius_cpu", [&] {
        auto x_data = x.data_ptr<scalar_t>();
        auto y_data = y.data_ptr<scalar_t>();
        typedef std::vector<std::vector<scalar_t>> vec_t;
        nanoflann::SearchParams params;
        params.sorted = false;  // Match order is irrelevant for the output.

        if (!ptr_x.has_value()) {  // Single example.
          // Copy `x` into the vector-of-vectors layout the adaptor expects.
          vec_t pts(x.size(0));
          for (int64_t i = 0; i < x.size(0); i++) {
            pts[i].resize(x.size(1));
            for (int64_t j = 0; j < x.size(1); j++) {
              pts[i][j] = x_data[i * x.size(1) + j];
            }
          }

          typedef KDTreeVectorOfVectorsAdaptor<vec_t, scalar_t> my_kd_tree_t;
          my_kd_tree_t mat_index(x.size(1), pts, 10);

          for (int64_t i = 0; i < y.size(0); i++) {
            std::vector<std::pair<size_t, scalar_t>> ret_matches;
            // nanoflann expects the *squared* search radius.
            size_t num_matches = mat_index.index->radiusSearch(
                y_data + i * y.size(1), r * r, ret_matches, params);

            for (size_t j = 0, count = 0;
                 j < num_matches && count < (size_t)max_num_neighbors; j++) {
              if (!ignore_same_index ||
                  ret_matches[j].first != static_cast<size_t>(i)) {
                out_vec.push_back(static_cast<int64_t>(ret_matches[j].first));
                out_vec.push_back(i);
                count++;
              }
            }
          }
        } else {  // Batched mode: build one KD-tree per example.
          auto ptr_x_data = ptr_x.value().data_ptr<int64_t>();
          auto ptr_y_data = ptr_y.value().data_ptr<int64_t>();

          for (int64_t b = 0; b < ptr_x.value().size(0) - 1; b++) {
            auto x_start = ptr_x_data[b], x_end = ptr_x_data[b + 1];
            auto y_start = ptr_y_data[b], y_end = ptr_y_data[b + 1];

            // Nothing to match in an empty example.
            if (x_start == x_end || y_start == y_end)
              continue;

            vec_t pts(x_end - x_start);
            for (int64_t i = 0; i < x_end - x_start; i++) {
              pts[i].resize(x.size(1));
              for (int64_t j = 0; j < x.size(1); j++) {
                pts[i][j] = x_data[(i + x_start) * x.size(1) + j];
              }
            }

            typedef KDTreeVectorOfVectorsAdaptor<vec_t, scalar_t> my_kd_tree_t;
            my_kd_tree_t mat_index(x.size(1), pts, 10);

            for (int64_t i = y_start; i < y_end; i++) {
              std::vector<std::pair<size_t, scalar_t>> ret_matches;
              size_t num_matches = mat_index.index->radiusSearch(
                  y_data + i * y.size(1), r * r, ret_matches, params);

              for (size_t j = 0, count = 0;
                   j < num_matches && count < (size_t)max_num_neighbors; j++) {
                // Tree indices are local to this example; shift by `x_start`
                // before comparing against or emitting global indices.
                if (!ignore_same_index ||
                    x_start + static_cast<int64_t>(ret_matches[j].first) != i) {
                  out_vec.push_back(
                      x_start + static_cast<int64_t>(ret_matches[j].first));
                  out_vec.push_back(i);
                  count++;
                }
              }
            }
          }
        }
      });

  const int64_t size = out_vec.size() / 2;
  // View the flat pairs as [size, 2], transpose, and swap the two rows so the
  // result is [2, size] with y (query) indices first. `.clone()` materializes
  // the data before `out_vec` goes out of scope.
  auto out =
      at::from_blob(out_vec.data(), {size, 2}, x.options().dtype(at::kLong));
  return out.t().index_select(0, at::tensor({1, 0})).clone();
}
106+
107+
} // namespace
108+
109+
// Register the nanoflann-based kernel as the CPU backend for `pyg::radius`.
TORCH_LIBRARY_IMPL(pyg, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::radius"), TORCH_FN(radius_kernel));
}
112+
113+
} // namespace ops
114+
} // namespace pyg
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#include "../radius.h"
2+
#include "utils.cuh"
3+
4+
#include <ATen/ATen.h>
5+
#include <ATen/cuda/CUDAContext.h>
6+
#include <torch/library.h>
7+
8+
namespace pyg {
9+
namespace ops {
10+
11+
namespace {
12+
13+
#define RADIUS_THREADS 256
14+
15+
template <typename scalar_t>
16+
__global__ void radius_cuda_kernel(const scalar_t* __restrict__ x,
17+
const scalar_t* __restrict__ y,
18+
const int64_t* __restrict__ ptr_x,
19+
const int64_t* __restrict__ ptr_y,
20+
int64_t* __restrict__ row,
21+
int64_t* __restrict__ col,
22+
const scalar_t r,
23+
const int64_t n,
24+
const int64_t m,
25+
const int64_t dim,
26+
const int64_t num_examples,
27+
const int64_t max_num_neighbors,
28+
const bool ignore_same_index) {
29+
const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
30+
if (n_y >= m)
31+
return;
32+
33+
int64_t count = 0;
34+
const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
35+
36+
for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
37+
scalar_t dist = 0;
38+
for (int64_t d = 0; d < dim; d++) {
39+
scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
40+
dist += diff * diff;
41+
}
42+
43+
if (scalar_lt(dist, r) && !(ignore_same_index && n_y == n_x)) {
44+
row[n_y * max_num_neighbors + count] = n_y;
45+
col[n_y * max_num_neighbors + count] = n_x;
46+
count++;
47+
}
48+
49+
if (count >= max_num_neighbors)
50+
break;
51+
}
52+
}
53+
54+
// CUDA backend for `pyg::radius`: launches one thread per query point that
// brute-forces squared Euclidean distances against its example's slice of
// `x`. Returns a `[2, num_edges]` long tensor with y (query) indices in row 0
// and x (reference) indices in row 1. `num_workers` is accepted for schema
// compatibility but is not used by this backend.
at::Tensor radius_cuda(const at::Tensor& x,
                       const at::Tensor& y,
                       const std::optional<at::Tensor>& ptr_x,
                       const std::optional<at::Tensor>& ptr_y,
                       double r,
                       int64_t max_num_neighbors,
                       int64_t num_workers,
                       bool ignore_same_index) {
  TORCH_CHECK(x.is_cuda() && y.is_cuda(), "Inputs must be CUDA tensors");
  TORCH_CHECK(x.is_contiguous() && y.is_contiguous(),
              "Inputs must be contiguous");

  // Bail out early on empty inputs: the `at::arange` fallback below would be
  // called with a step of zero, and a zero-sized grid is an invalid CUDA
  // launch configuration.
  if (x.size(0) == 0 || y.size(0) == 0)
    return at::empty({2, 0}, x.options().dtype(at::kLong));

  std::optional<at::Tensor> ptr_x_v = ptr_x;
  std::optional<at::Tensor> ptr_y_v = ptr_y;

  // Without explicit offsets, treat all points as one example: [0, num].
  if (!ptr_x_v.has_value())
    ptr_x_v =
        at::arange(0, x.size(0) + 1, x.size(0), x.options().dtype(at::kLong));
  if (!ptr_y_v.has_value())
    ptr_y_v =
        at::arange(0, y.size(0) + 1, y.size(0), y.options().dtype(at::kLong));

  TORCH_CHECK(ptr_x_v.value().numel() == ptr_y_v.value().numel(),
              "ptr_x and ptr_y must have the same number of elements");

  // Preallocate dense per-query output slots; slots left at -1 by the kernel
  // are masked out below.
  auto row =
      at::full({y.size(0) * max_num_neighbors}, -1, ptr_y_v.value().options());
  auto col =
      at::full({y.size(0) * max_num_neighbors}, -1, ptr_y_v.value().options());

  dim3 BLOCKS((y.size(0) + RADIUS_THREADS - 1) / RADIUS_THREADS);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(),
      "radius_cuda", [&] {
        // The kernel compares squared distances, so pass r^2.
        radius_cuda_kernel<scalar_t><<<BLOCKS, RADIUS_THREADS, 0, stream>>>(
            x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
            ptr_x_v.value().data_ptr<int64_t>(),
            ptr_y_v.value().data_ptr<int64_t>(), row.data_ptr<int64_t>(),
            col.data_ptr<int64_t>(), (scalar_t)(r * r), x.size(0), y.size(0),
            x.size(1), ptr_x_v.value().numel() - 1, max_num_neighbors,
            ignore_same_index);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });

  // Compact the dense slots into a [2, num_edges] edge list.
  auto mask = row != -1;
  return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
102+
103+
} // namespace
104+
105+
// Register the brute-force kernel as the CUDA backend for `pyg::radius`.
TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::radius"), TORCH_FN(radius_cuda));
}
108+
109+
} // namespace ops
110+
} // namespace pyg

pyg_lib/csrc/ops/radius.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include "radius.h"
2+
3+
#include <ATen/core/dispatch/Dispatcher.h>
4+
#include <torch/library.h>
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
// Device-agnostic entry point for `pyg::radius`. Validates the arguments,
// normalizes memory layout, and dispatches to the registered CPU/CUDA
// kernel. Returns a `[2, num_edges]` long tensor of (query, reference)
// index pairs.
PYG_API at::Tensor radius(const at::Tensor& x,
                          const at::Tensor& y,
                          const std::optional<at::Tensor>& ptr_x,
                          const std::optional<at::Tensor>& ptr_y,
                          double r,
                          int64_t max_num_neighbors,
                          int64_t num_workers,
                          bool ignore_same_index) {
  at::TensorArg x_arg{x, "x", 0};
  at::TensorArg y_arg{y, "y", 1};
  at::CheckedFrom c{"radius"};

  at::checkAllDefined(c, {x_arg, y_arg});
  at::checkDim(c, x_arg, 2);
  at::checkDim(c, y_arg, 2);
  TORCH_CHECK(x.size(1) == y.size(1), "x and y must have the same feature dim");
  TORCH_CHECK(r > 0, "r must be positive");
  // Batched mode is all-or-nothing; a half-specified pair would make the
  // backends dereference an empty optional.
  TORCH_CHECK(ptr_x.has_value() == ptr_y.has_value(),
              "'ptr_x' and 'ptr_y' must either both be given or both be absent");

  auto x_c = x.contiguous();
  auto y_c = y.contiguous();

  // The backends read the offset tensors through raw int64 pointers, so
  // enforce dtype and contiguity once here rather than in every kernel.
  std::optional<at::Tensor> ptr_x_c = ptr_x;
  std::optional<at::Tensor> ptr_y_c = ptr_y;
  if (ptr_x_c.has_value()) {
    TORCH_CHECK(ptr_x_c.value().scalar_type() == at::kLong &&
                    ptr_y_c.value().scalar_type() == at::kLong,
                "'ptr_x' and 'ptr_y' must be of type long");
    ptr_x_c = ptr_x_c.value().contiguous();
    ptr_y_c = ptr_y_c.value().contiguous();
  }

  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("pyg::radius", "")
                       .typed<decltype(radius)>();
  return op.call(x_c, y_c, ptr_x_c, ptr_y_c, r, max_num_neighbors, num_workers,
                 ignore_same_index);
}
36+
37+
// Declare the `pyg::radius` operator schema; backend kernels are registered
// separately via TORCH_LIBRARY_IMPL in the CPU and CUDA translation units.
TORCH_LIBRARY_FRAGMENT(pyg, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "pyg::radius(Tensor x, Tensor y, Tensor? ptr_x=None, "
      "Tensor? ptr_y=None, float r=1.0, int max_num_neighbors=32, "
      "int num_workers=1, bool ignore_same_index=False) -> Tensor"));
}
43+
44+
} // namespace ops
45+
} // namespace pyg

pyg_lib/csrc/ops/radius.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "pyg_lib/csrc/macros.h"
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
// Finds for each point in `y` all points in `x` within Euclidean distance
// `r`, returning a [2, num_edges] long tensor with y (query) indices in
// row 0 and x (reference) indices in row 1.
//
// Args:
//   x: Reference point set of shape [num_x, dim].
//   y: Query point set of shape [num_y, dim]; must share `dim` with `x`.
//   ptr_x: Optional CSR offsets partitioning `x` into batch examples.
//   ptr_y: Optional CSR offsets partitioning `y` into batch examples.
//   r: Search radius (must be positive).
//   max_num_neighbors: Maximum number of matches returned per query point.
//   num_workers: NOTE(review): appears unused by both visible backends —
//     confirm before relying on it.
//   ignore_same_index: If true, drops pairs whose x and y indices are equal
//     (self-loops when `x` and `y` are the same point set).
PYG_API at::Tensor radius(const at::Tensor& x,
                          const at::Tensor& y,
                          const std::optional<at::Tensor>& ptr_x,
                          const std::optional<at::Tensor>& ptr_y,
                          double r,
                          int64_t max_num_neighbors,
                          int64_t num_workers,
                          bool ignore_same_index);
17+
18+
} // namespace ops
19+
} // namespace pyg

test/ops/test_radius.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import pytest
2+
import torch
3+
4+
import pyg_lib
5+
from pyg_lib.testing import withCUDA
6+
7+
8+
@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_radius_basic(dtype: torch.dtype, device: torch.device) -> None:
    """Single query: exactly the two in-radius points are returned."""
    x = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [10.0, 0.0]],
                     dtype=dtype, device=device)
    y = torch.tensor([[0.5, 0.0]], dtype=dtype, device=device)

    out = pyg_lib.ops.radius(x, y, r=1.5)
    # The output is always [2, num_edges], so the original
    # `out.shape[0] == 2` was trivially true; assert the edge count instead.
    assert out.shape == (2, 2)

    # x[0] and x[1] both lie at distance 0.5 from the query point.
    refs = out[1].sort()[0]
    assert refs.tolist() == [0, 1]
    assert (out[0] == 0).all()
22+
23+
24+
@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_radius_correctness(dtype: torch.dtype, device: torch.device) -> None:
    """Every returned edge connects a query/reference pair within radius."""
    x = torch.randn(30, 3, dtype=dtype, device=device)
    y = torch.randn(10, 3, dtype=dtype, device=device)
    r = 1.5

    out = pyg_lib.ops.radius(x, y, r=r, max_num_neighbors=100)

    # Gather the distance of every returned (query, reference) pair in one
    # vectorized lookup instead of iterating edge by edge.
    dists = torch.cdist(y.float(), x.float())
    edge_dists = dists[out[0], out[1]]
    assert bool((edge_dists <= r + 1e-5).all())
38+
39+
40+
@withCUDA
def test_radius_max_num_neighbors(device: torch.device) -> None:
    """No query point receives more than `max_num_neighbors` matches."""
    x = torch.randn(50, 3, device=device)
    y = torch.randn(10, 3, device=device)

    out = pyg_lib.ops.radius(x, y, r=100.0, max_num_neighbors=5)

    # Count matches per query in one shot rather than looping over queries.
    counts = torch.bincount(out[0], minlength=y.size(0))
    assert bool((counts <= 5).all())
49+
50+
51+
@withCUDA
def test_radius_batched(device: torch.device) -> None:
    """Edges never cross example boundaries in batched mode."""
    x = torch.randn(20, 3, device=device)
    y = torch.randn(15, 3, device=device)
    ptr_x = torch.tensor([0, 10, 20], dtype=torch.long, device=device)
    ptr_y = torch.tensor([0, 8, 15], dtype=torch.long, device=device)

    out = pyg_lib.ops.radius(x, y, r=5.0, ptr_x=ptr_x, ptr_y=ptr_y)

    # Map each endpoint to its batch example and require both endpoints of
    # every edge to agree (queries split at index 8, references at 10).
    query_batch = (out[0] >= 8).long()
    ref_batch = (out[1] >= 10).long()
    assert bool((query_batch == ref_batch).all())
67+
68+
69+
@withCUDA
def test_radius_ignore_same_index(device: torch.device) -> None:
    """With `ignore_same_index=True`, no self-loop edges are returned."""
    x = torch.randn(10, 3, device=device)

    out = pyg_lib.ops.radius(x, x, r=100.0, max_num_neighbors=100,
                             ignore_same_index=True)

    # Querying a set against itself must not pair any point with itself.
    self_loops = out[0] == out[1]
    assert not bool(self_loops.any())

0 commit comments

Comments
 (0)