Skip to content

Commit 107c05f

Browse files
committed
Add nearest dispatch, CPU and CUDA kernels
New C++ CPU kernel (brute-force pairwise + argmin) replacing the original scipy fallback. CUDA kernel uses shared-memory argmin reduction with 1 block per query point. Tests included.
1 parent d1c8562 commit 107c05f

File tree

5 files changed

+299
-0
lines changed

5 files changed

+299
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#include "../nearest.h"

#include <ATen/ATen.h>
#include <torch/library.h>

#include <limits>
#include <optional>

namespace pyg {
namespace ops {

namespace {

// CPU implementation of `pyg::nearest`: for every row of `x`, returns the
// index of the row in `y` with the smallest squared Euclidean distance
// (brute force over all candidates).
//
// When `ptr_x`/`ptr_y` are given, they are CSR-style batch offsets of length
// `num_batches + 1`: points in `x[ptr_x[b]:ptr_x[b+1]]` only search
// candidates in `y[ptr_y[b]:ptr_y[b+1]]`.
at::Tensor nearest_kernel(const at::Tensor& x,
                          const at::Tensor& y,
                          const std::optional<at::Tensor>& ptr_x,
                          const std::optional<at::Tensor>& ptr_y) {
  auto out = at::empty({x.size(0)}, x.options().dtype(at::kLong));

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_cpu", [&] {
    auto x_data = x.data_ptr<scalar_t>();
    auto y_data = y.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<int64_t>();
    auto dim = x.size(1);

    // Scan `y` rows [y_start, y_end) for the nearest neighbor of `x` row `i`
    // and write its index into `out`. Strict `<` keeps the first (lowest)
    // index on exact distance ties.
    auto find_nearest = [&](int64_t i, int64_t y_start, int64_t y_end) {
      scalar_t best_dist = std::numeric_limits<scalar_t>::max();
      int64_t best_idx = y_start;
      for (int64_t j = y_start; j < y_end; j++) {
        scalar_t dist = 0;
        for (int64_t d = 0; d < dim; d++) {
          scalar_t diff = x_data[i * dim + d] - y_data[j * dim + d];
          dist += diff * diff;
        }
        if (dist < best_dist) {
          best_dist = dist;
          best_idx = j;
        }
      }
      out_data[i] = best_idx;
    };

    if (!ptr_x.has_value()) {
      // Single batch: every `x` row searches all of `y`.
      for (int64_t i = 0; i < x.size(0); i++) {
        find_nearest(i, 0, y.size(0));
      }
    } else {
      // Batched: both offset vectors are required; dereferencing a missing
      // `ptr_y` below would be undefined behavior.
      TORCH_CHECK(ptr_y.has_value(),
                  "'ptr_y' must be provided whenever 'ptr_x' is provided");
      auto ptr_x_data = ptr_x.value().data_ptr<int64_t>();
      auto ptr_y_data = ptr_y.value().data_ptr<int64_t>();
      auto num_batches = ptr_x.value().size(0) - 1;

      for (int64_t b = 0; b < num_batches; b++) {
        auto y_start = ptr_y_data[b], y_end = ptr_y_data[b + 1];
        for (int64_t i = ptr_x_data[b]; i < ptr_x_data[b + 1]; i++) {
          find_nearest(i, y_start, y_end);
        }
      }
    }
  });

  return out;
}

} // namespace

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::nearest"), TORCH_FN(nearest_kernel));
}

} // namespace ops
} // namespace pyg
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#include "../nearest.h"
#include "utils.cuh"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

#include <limits>
#include <optional>

namespace pyg {
namespace ops {

namespace {

// Launch configuration: one block per query point, NEAREST_THREADS threads
// cooperating on the argmin over that point's candidate set.
#define NEAREST_THREADS 1024

// One block handles query point `n_x = blockIdx.x`. Each thread scans a
// strided subset of the candidate rows of `y`, then the block combines the
// per-thread (distance, index) pairs with a shared-memory tree reduction.
// Preconditions: gridDim.x == x.size(0), blockDim.x == NEAREST_THREADS,
// `x`/`y` contiguous, `ptr_x`/`ptr_y` int64 offsets of length batch_size + 1.
template <typename scalar_t>
__global__ void nearest_cuda_kernel(const scalar_t* __restrict__ x,
                                    const scalar_t* __restrict__ y,
                                    const int64_t* __restrict__ ptr_x,
                                    const int64_t* __restrict__ ptr_y,
                                    int64_t* __restrict__ out,
                                    int64_t batch_size,
                                    int64_t dim) {
  const int64_t thread_idx = threadIdx.x;
  const int64_t n_x = blockIdx.x;

  // Find the batch this query point belongs to (linear scan; the number of
  // batches is expected to be small relative to the candidate scan below).
  int64_t batch_idx = 0;
  for (int64_t b = 0; b < batch_size; b++) {
    if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) {
      batch_idx = b;
      break;
    }
  }

  const int64_t y_start = ptr_y[batch_idx];
  const int64_t y_end = ptr_y[batch_idx + 1];

  __shared__ scalar_t best_dist[NEAREST_THREADS];
  __shared__ int64_t best_dist_idx[NEAREST_THREADS];

  // Per-thread argmin over a strided slice of [y_start, y_end). The sentinel
  // is the type's maximum rather than a magic 1e38: for double inputs,
  // squared distances can legitimately exceed 1e38, which would have made
  // every candidate "not better" and returned a wrong index.
  scalar_t best = std::numeric_limits<scalar_t>::max();
  int64_t best_idx = y_start;
  for (int64_t n_y = y_start + thread_idx; n_y < y_end;
       n_y += NEAREST_THREADS) {
    scalar_t dist = 0;
    for (int64_t d = 0; d < dim; d++) {
      scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
      dist += diff * diff;
    }

    if (scalar_lt(dist, best)) {
      best = dist;
      best_idx = n_y;
    }
  }

  best_dist[thread_idx] = best;
  best_dist_idx[thread_idx] = best_idx;

  // Pairwise tree reduction over shared memory; the barrier sits outside the
  // divergent branch so every thread reaches it. After the loop, slot 0
  // holds the block-wide argmin.
  for (int64_t u = 0; (1 << u) < NEAREST_THREADS; u++) {
    __syncthreads();
    if (thread_idx < (NEAREST_THREADS >> (u + 1))) {
      int64_t idx_1 = (thread_idx * 2) << u;
      int64_t idx_2 = (thread_idx * 2 + 1) << u;
      if (scalar_gt(best_dist[idx_1], best_dist[idx_2])) {
        best_dist[idx_1] = best_dist[idx_2];
        best_dist_idx[idx_1] = best_dist_idx[idx_2];
      }
    }
  }

  __syncthreads();
  if (thread_idx == 0) {
    out[n_x] = best_dist_idx[0];
  }
}

// Host-side wrapper: validates inputs, materializes trivial single-batch
// offsets when none are given, and launches one block per query point on the
// current CUDA stream.
at::Tensor nearest_cuda(const at::Tensor& x,
                        const at::Tensor& y,
                        const std::optional<at::Tensor>& ptr_x,
                        const std::optional<at::Tensor>& ptr_y) {
  TORCH_CHECK(x.is_cuda() && y.is_cuda(), "Inputs must be CUDA tensors");
  TORCH_CHECK(x.is_contiguous() && y.is_contiguous(),
              "Inputs must be contiguous");

  auto out = at::empty({x.size(0)}, x.options().dtype(at::kLong));

  // A grid of 0 blocks is an invalid launch configuration (and the arange
  // below would use a step of 0); an empty `x` trivially yields an empty
  // result.
  if (x.size(0) == 0)
    return out;

  std::optional<at::Tensor> ptr_x_v = ptr_x;
  std::optional<at::Tensor> ptr_y_v = ptr_y;

  // Without explicit offsets, treat all points as one batch: ptr = [0, N].
  if (!ptr_x_v.has_value())
    ptr_x_v =
        at::arange(0, x.size(0) + 1, x.size(0), x.options().dtype(at::kLong));
  if (!ptr_y_v.has_value())
    ptr_y_v =
        at::arange(0, y.size(0) + 1, y.size(0), y.options().dtype(at::kLong));

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_cuda", [&] {
    nearest_cuda_kernel<scalar_t><<<x.size(0), NEAREST_THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x_v.value().data_ptr<int64_t>(),
        ptr_y_v.value().data_ptr<int64_t>(), out.data_ptr<int64_t>(),
        ptr_x_v.value().size(0) - 1, x.size(1));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  return out;
}

} // namespace

TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::nearest"), TORCH_FN(nearest_cuda));
}

} // namespace ops
} // namespace pyg

pyg_lib/csrc/ops/nearest.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "nearest.h"

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

#include <optional>

namespace pyg {
namespace ops {

// Finds, for every point in `x`, the index of its nearest neighbor in `y`
// under squared Euclidean distance. Optional `ptr_x`/`ptr_y` are CSR-style
// batch offsets restricting the search to matching batches. Returns an
// int64 tensor of shape `[x.size(0)]`.
PYG_API at::Tensor nearest(const at::Tensor& x,
                           const at::Tensor& y,
                           const std::optional<at::Tensor>& ptr_x,
                           const std::optional<at::Tensor>& ptr_y) {
  at::TensorArg x_arg{x, "x", 0};
  at::TensorArg y_arg{y, "y", 1};
  at::CheckedFrom c{"nearest"};

  at::checkAllDefined(c, {x_arg, y_arg});

  // Flatten trailing dimensions so backend kernels only deal with 2-D input.
  auto x_c = x.view({x.size(0), -1}).contiguous();
  auto y_c = y.view({y.size(0), -1}).contiguous();

  TORCH_CHECK(x_c.size(1) == y_c.size(1),
              "x and y must have the same feature dim");

  // The backend kernels index `ptr_y` with batches derived from `ptr_x`, so
  // the two offset vectors must be provided together and agree on the number
  // of batches. Validate here, once, instead of in every backend.
  TORCH_CHECK(ptr_x.has_value() == ptr_y.has_value(),
              "'ptr_x' and 'ptr_y' must either both be provided or both be "
              "omitted");
  if (ptr_x.has_value()) {
    TORCH_CHECK(ptr_x.value().size(0) == ptr_y.value().size(0),
                "'ptr_x' and 'ptr_y' must hold the same number of batches");
  }

  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("pyg::nearest", "")
                       .typed<decltype(nearest)>();
  return op.call(x_c, y_c, ptr_x, ptr_y);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "pyg::nearest(Tensor x, Tensor y, Tensor? ptr_x=None, "
      "Tensor? ptr_y=None) -> Tensor"));
}

} // namespace ops
} // namespace pyg

pyg_lib/csrc/ops/nearest.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once

#include <optional>

#include <ATen/ATen.h>
#include "pyg_lib/csrc/macros.h"

namespace pyg {
namespace ops {

// For every point in `x`, returns the index (into `y`) of its nearest
// neighbor under squared Euclidean distance. `ptr_x`/`ptr_y` are optional
// CSR-style batch offsets of equal length; when given, points are only
// matched against candidates within the same batch.
// Output: int64 tensor of shape `[x.size(0)]`.
PYG_API at::Tensor nearest(const at::Tensor& x,
                           const at::Tensor& y,
                           const std::optional<at::Tensor>& ptr_x,
                           const std::optional<at::Tensor>& ptr_y);

} // namespace ops
} // namespace pyg

test/ops/test_nearest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
import torch
3+
4+
import pyg_lib
5+
from pyg_lib.testing import withCUDA
6+
7+
8+
@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_nearest_basic(dtype: torch.dtype, device: torch.device) -> None:
    # Two query points on the x-axis, each with an unambiguous closest
    # candidate in `y`.
    x = torch.tensor([[0.0, 0.0], [3.0, 0.0]], dtype=dtype, device=device)
    y = torch.tensor([[1.0, 0.0], [2.0, 0.0]], dtype=dtype, device=device)

    out = pyg_lib.ops.nearest(x, y)
    assert out.shape == (2, )
    # (0, 0) is closest to y[0] = (1, 0), and (3, 0) to y[1] = (2, 0):
    assert out[0].item() == 0
    assert out[1].item() == 1
18+
19+
20+
@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_nearest_correctness(dtype: torch.dtype, device: torch.device) -> None:
    """Compare the kernel against a dense ``cdist`` + ``argmin`` reference.

    The reference is computed in the *input* dtype: downcasting double
    inputs to float32 could flip the argmin for near-tied distances and
    make the test flaky.
    """
    x = torch.randn(20, 5, dtype=dtype, device=device)
    y = torch.randn(15, 5, dtype=dtype, device=device)

    out = pyg_lib.ops.nearest(x, y)

    dists = torch.cdist(x, y)
    expected = dists.argmin(dim=1)
    assert torch.equal(out, expected)
32+
33+
34+
@withCUDA
def test_nearest_batched(device: torch.device) -> None:
    """Batched queries must match only within their own batch, and the
    matched indices must be the true per-batch nearest neighbors."""
    x = torch.randn(20, 3, device=device)
    y = torch.randn(15, 3, device=device)
    ptr_x = torch.tensor([0, 10, 20], dtype=torch.long, device=device)
    ptr_y = torch.tensor([0, 8, 15], dtype=torch.long, device=device)

    out = pyg_lib.ops.nearest(x, y, ptr_x=ptr_x, ptr_y=ptr_y)
    assert out.shape == (20, )

    # Each batch may only pick candidates from its own slice of `y`:
    assert (out[:10] >= 0).all() and (out[:10] < 8).all()
    assert (out[10:] >= 8).all() and (out[10:] < 15).all()

    # Per-batch brute-force reference (offsets shifted back to global
    # indices into `y`):
    for b in range(2):
        x_lo, x_hi = ptr_x[b].item(), ptr_x[b + 1].item()
        y_lo, y_hi = ptr_y[b].item(), ptr_y[b + 1].item()
        dists = torch.cdist(x[x_lo:x_hi], y[y_lo:y_hi])
        expected = dists.argmin(dim=1) + y_lo
        assert torch.equal(out[x_lo:x_hi], expected)

0 commit comments

Comments
 (0)