Skip to content

Commit 6ab0a1f

Browse files
committed
Add fps CUDA kernel
Port CUDA farthest point sampling kernel with shared-memory argmax reduction. Uses explicit scalar_gt/scalar_lt/scalar_min helpers to avoid NVCC operator overload ambiguity with c10::SymInt.
1 parent a9f95f8 commit 6ab0a1f

File tree

2 files changed

+164
-18
lines changed

2 files changed

+164
-18
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
#include "../fps.h"
2+
3+
#include <ATen/ATen.h>
4+
#include <ATen/cuda/CUDAContext.h>
5+
#include <torch/library.h>
6+
7+
namespace pyg {
8+
namespace ops {
9+
10+
namespace {
11+
12+
#define FPS_THREADS 256
13+
14+
// Non-template comparison/min helpers with one overload per floating type.
// Deliberately NOT templated: templated overloads make NVCC report ambiguous
// operator-overload errors involving c10::SymInt (error #3343).
__device__ __forceinline__ bool scalar_gt(float a, float b) {
  return b < a;  // equivalent to a > b (false for NaN operands either way)
}
__device__ __forceinline__ bool scalar_gt(double a, double b) {
  return b < a;
}
__device__ __forceinline__ bool scalar_lt(float a, float b) {
  return b > a;  // equivalent to a < b
}
__device__ __forceinline__ bool scalar_lt(double a, double b) {
  return b > a;
}
// fminf/fmin are the CUDA math-library minimum for float/double respectively.
__device__ __forceinline__ float scalar_min(float a, float b) {
  return fminf(a, b);
}
__device__ __forceinline__ double scalar_min(double a, double b) {
  return fmin(a, b);
}
34+
35+
// Farthest point sampling, one batch element per CUDA block.
//
// Launch layout: grid = (num_batches), block = (FPS_THREADS). Each block
// handles the point segment [ptr[b], ptr[b + 1]) of its batch element and
// iteratively appends the point with the largest running squared Euclidean
// distance to the already-selected set, writing chosen point indices into
// out[out_ptr[b] .. out_ptr[b + 1]).
//
// Preconditions (established by the host wrapper fps_cuda):
//   - src is contiguous and indexed as src[dim * n + d], i.e. an (N, dim)
//     row-major matrix;
//   - dist has one entry per point and is pre-initialized to a value larger
//     than any real squared distance, so the first scalar_min() pass stores
//     the true distance to the start point;
//   - start[b] is the offset (within segment b) of the first selected point.
template <typename scalar_t>
__global__ void fps_cuda_kernel(const scalar_t* src,
                                const int64_t* ptr,
                                const int64_t* out_ptr,
                                const int64_t* start,
                                scalar_t* dist,
                                int64_t* out,
                                int64_t dim) {
  const int64_t thread_idx = threadIdx.x;
  const int64_t batch_idx = blockIdx.x;

  const int64_t start_idx = ptr[batch_idx];
  const int64_t end_idx = ptr[batch_idx + 1];

  // Per-thread best candidate, reduced across the block below.
  __shared__ scalar_t best_dist[FPS_THREADS];
  __shared__ int64_t best_dist_idx[FPS_THREADS];

  // Seed the output with this batch element's starting point.
  if (thread_idx == 0) {
    out[out_ptr[batch_idx]] = start_idx + start[batch_idx];
  }

  // All threads of the block iterate the same m range, so the barriers
  // inside this loop are reached uniformly.
  for (int64_t m = out_ptr[batch_idx] + 1; m < out_ptr[batch_idx + 1]; m++) {
    __syncthreads();  // make out[m - 1] (written by thread 0) visible to all
    int64_t old = out[m - 1];

    // Sentinel: any real squared distance (>= 0) beats -1.
    scalar_t best = (scalar_t)-1.;
    int64_t best_idx = 0;

    // Block-strided scan over the segment: fold the distance to the most
    // recently selected point into the running minimum dist[n], then track
    // this thread's farthest remaining point.
    for (int64_t n = start_idx + thread_idx; n < end_idx; n += FPS_THREADS) {
      scalar_t tmp, dd = (scalar_t)0.;
      for (int64_t d = 0; d < dim; d++) {
        tmp = src[dim * old + d] - src[dim * n + d];
        dd += tmp * tmp;
      }
      dd = scalar_min(dist[n], dd);
      dist[n] = dd;
      if (scalar_gt(dd, best)) {
        best = dd;
        best_idx = n;
      }
    }

    best_dist[thread_idx] = best;
    best_dist_idx[thread_idx] = best_idx;

    // Shared-memory argmax reduction with increasing stride; afterwards
    // best_dist[0] / best_dist_idx[0] hold the block-wide farthest point.
    for (int64_t i = 1; i < FPS_THREADS; i *= 2) {
      __syncthreads();
      if ((thread_idx + i) < FPS_THREADS &&
          scalar_lt(best_dist[thread_idx], best_dist[thread_idx + i])) {
        best_dist[thread_idx] = best_dist[thread_idx + i];
        best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i];
      }
    }

    __syncthreads();  // publish the reduction result before thread 0 reads it
    if (thread_idx == 0) {
      out[m] = best_dist_idx[0];
    }
  }
}
95+
96+
// CUDA implementation of farthest point sampling.
//
// Args:
//   src: contiguous 2-D (N, D) CUDA tensor of point coordinates.
//   ptr: (B + 1,) int64 CUDA tensor of CSR-style segment boundaries; batch b
//        owns rows [ptr[b], ptr[b + 1]).
//   ratio: fraction of each segment to sample; ceil(segment_size * ratio)
//          points are selected per segment.
//   random_start: if true, the first selected point of each segment is drawn
//                 uniformly at random; otherwise it is the segment's first row.
//
// Returns a 1-D int64 tensor of selected row indices into `src`, grouped by
// segment in selection order.
at::Tensor fps_cuda(const at::Tensor& src,
                    const at::Tensor& ptr,
                    double ratio,
                    bool random_start) {
  TORCH_CHECK(src.is_cuda(), "src must be a CUDA tensor");
  TORCH_CHECK(src.is_contiguous(), "src must be contiguous");
  TORCH_CHECK(src.dim() == 2, "src must be two-dimensional");
  TORCH_CHECK(ptr.is_cuda(), "ptr must be a CUDA tensor");
  TORCH_CHECK(ptr.scalar_type() == at::kLong, "ptr must be of type long");

  int64_t batch_size = ptr.numel() - 1;
  int64_t D = src.size(1);

  // A zero-sized grid is an invalid launch configuration; bail out early for
  // an empty batch instead of launching the kernel.
  if (batch_size <= 0) {
    return at::empty({0}, ptr.options());
  }

  auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
  // Exclusive cumulative sum of per-segment sample counts: out_ptr[b] is the
  // output offset of segment b, and out_ptr[-1] the total number of samples.
  auto out_ptr = deg.to(at::kFloat) * ratio;
  out_ptr = out_ptr.ceil().to(at::kLong).cumsum(0);
  out_ptr = at::cat({at::zeros({1}, ptr.options()), out_ptr}, 0);

  // Per-segment offset of the initially selected point.
  at::Tensor start;
  if (random_start) {
    start = at::rand({batch_size}, src.options());
    start = (start * deg.to(at::kFloat)).to(at::kLong);
  } else {
    start = at::zeros({batch_size}, ptr.options());
  }

  // Running minimum squared distance to the selected set, initialized above
  // any expected squared distance (kernel precondition).
  auto dist = at::full({src.size(0)}, 5e4, src.options());

  // item() synchronizes with the current stream and surfaces CUDA errors,
  // unlike a raw, unchecked cudaMemcpy of out_ptr[-1]'s storage.
  const int64_t out_total = out_ptr[-1].item<int64_t>();
  auto out = at::empty({out_total}, out_ptr.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_cuda", [&] {
    fps_cuda_kernel<scalar_t><<<batch_size, FPS_THREADS, 0, stream>>>(
        src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
        out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
        dist.data_ptr<scalar_t>(), out.data_ptr<int64_t>(), D);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  return out;
}
138+
139+
} // namespace
140+
141+
// Register fps_cuda as the CUDA backend implementation of the `pyg::fps`
// operator (the operator schema is declared elsewhere in the library).
TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::fps"), TORCH_FN(fps_cuda));
}
144+
145+
} // namespace ops
146+
} // namespace pyg

test/ops/test_fps.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,29 @@
22
import torch
33

44
import pyg_lib
5+
from pyg_lib.testing import withCUDA
56

67

8+
@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_fps_output_size(dtype: torch.dtype, device: torch.device) -> None:
    # A single segment of 20 points sampled at ratio 0.5 -> ceil(20 * 0.5)
    # = 10 selected indices.
    num_points, num_dims = 20, 3
    src = torch.randn(num_points, num_dims, dtype=dtype, device=device)
    ptr = torch.tensor([0, num_points], dtype=torch.long, device=device)

    out = pyg_lib.ops.fps(src, ptr, ratio=0.5, random_start=False)

    assert out.shape == (10, )
    assert out.dtype == torch.long
    # Every selected index must be a valid row of `src`:
    assert out.min() >= 0
    assert out.max() < num_points
1920

2021

22+
@withCUDA
2123
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
22-
def test_fps_farthest_property(dtype: torch.dtype) -> None:
23-
# After FPS, the minimum pairwise distance between selected points
24-
# should be >= the greedy guarantee.
25-
src = torch.randn(50, 3, dtype=dtype)
26-
ptr = torch.tensor([0, 50], dtype=torch.long)
24+
def test_fps_farthest_property(dtype: torch.dtype,
25+
device: torch.device) -> None:
26+
src = torch.randn(50, 3, dtype=dtype, device=device)
27+
ptr = torch.tensor([0, 50], dtype=torch.long, device=device)
2728

2829
out = pyg_lib.ops.fps(src, ptr, ratio=0.2, random_start=False)
2930
selected = src[out]
@@ -33,25 +34,24 @@ def test_fps_farthest_property(dtype: torch.dtype) -> None:
3334
assert min_dist > 0
3435

3536

36-
def test_fps_multi_batch() -> None:
37-
src = torch.randn(30, 3)
38-
ptr = torch.tensor([0, 10, 30], dtype=torch.long)
37+
@withCUDA
38+
def test_fps_multi_batch(device: torch.device) -> None:
39+
src = torch.randn(30, 3, device=device)
40+
ptr = torch.tensor([0, 10, 30], dtype=torch.long, device=device)
3941

4042
out = pyg_lib.ops.fps(src, ptr, ratio=0.5, random_start=False)
4143
# Batch 0: ceil(10 * 0.5) = 5, Batch 1: ceil(20 * 0.5) = 10
4244
assert out.shape == (15, )
43-
# First 5 indices in batch 0:
4445
assert (out[:5] < 10).all()
4546
assert (out[:5] >= 0).all()
46-
# Next 10 in batch 1:
4747
assert (out[5:] >= 10).all()
4848
assert (out[5:] < 30).all()
4949

5050

51-
def test_fps_random_start() -> None:
52-
src = torch.randn(20, 3)
53-
ptr = torch.tensor([0, 20], dtype=torch.long)
51+
@withCUDA
52+
def test_fps_random_start(device: torch.device) -> None:
53+
src = torch.randn(20, 3, device=device)
54+
ptr = torch.tensor([0, 20], dtype=torch.long, device=device)
5455

5556
out_det = pyg_lib.ops.fps(src, ptr, ratio=0.5, random_start=False)
56-
# Deterministic: first selected index is always 0
5757
assert out_det[0] == 0

0 commit comments

Comments
 (0)