Skip to content

Commit b99c438

Browse files
authored
Add fps dispatch + CPU kernel (#587)
1 parent e812618 commit b99c438

File tree

5 files changed

+242
-0
lines changed

5 files changed

+242
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include "../fps.h"
2+
3+
#include <ATen/ATen.h>
4+
#include <ATen/Parallel.h>
5+
#include <torch/library.h>
6+
7+
namespace pyg {
8+
namespace ops {
9+
10+
namespace {
11+
12+
at::Tensor fps_kernel(const at::Tensor& src,
13+
const at::Tensor& ptr,
14+
double ratio,
15+
bool random_start) {
16+
auto N = src.size(0);
17+
auto D = src.size(1);
18+
auto batch_size = ptr.numel() - 1;
19+
20+
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
21+
auto out_ptr = deg.to(at::kFloat) * ratio;
22+
out_ptr = out_ptr.ceil().to(at::kLong).cumsum(0);
23+
24+
auto out = at::empty({out_ptr[-1].data_ptr<int64_t>()[0]}, ptr.options());
25+
26+
auto ptr_data = ptr.data_ptr<int64_t>();
27+
auto out_ptr_data = out_ptr.data_ptr<int64_t>();
28+
auto out_data = out.data_ptr<int64_t>();
29+
30+
int64_t grain_size = 1;
31+
at::parallel_for(0, batch_size, grain_size, [&](int64_t begin, int64_t end) {
32+
for (int64_t b = begin; b < end; b++) {
33+
auto src_start = ptr_data[b];
34+
auto src_end = ptr_data[b + 1];
35+
auto out_start = b == 0 ? 0 : out_ptr_data[b - 1];
36+
auto out_end = out_ptr_data[b];
37+
38+
auto y = src.narrow(0, src_start, src_end - src_start);
39+
40+
int64_t start_idx = 0;
41+
if (random_start)
42+
start_idx = rand() % y.size(0);
43+
44+
out_data[out_start] = src_start + start_idx;
45+
auto dist = (y - y[start_idx]).pow_(2).sum(1);
46+
47+
for (int64_t i = 1; i < out_end - out_start; i++) {
48+
int64_t argmax = dist.argmax().data_ptr<int64_t>()[0];
49+
out_data[out_start + i] = src_start + argmax;
50+
dist = at::min(dist, (y - y[argmax]).pow_(2).sum(1));
51+
}
52+
}
53+
});
54+
55+
return out;
56+
}
57+
58+
} // namespace
59+
60+
TORCH_LIBRARY_IMPL(pyg, CPU, m) {
61+
m.impl(TORCH_SELECTIVE_NAME("pyg::fps"), TORCH_FN(fps_kernel));
62+
}
63+
64+
} // namespace ops
65+
} // namespace pyg

pyg_lib/csrc/ops/fps.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "fps.h"
2+
3+
#include <ATen/core/dispatch/Dispatcher.h>
4+
#include <torch/library.h>
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
PYG_API at::Tensor fps(const at::Tensor& src,
10+
const at::Tensor& ptr,
11+
double ratio,
12+
bool random_start) {
13+
at::TensorArg src_arg{src, "src", 0};
14+
at::TensorArg ptr_arg{ptr, "ptr", 1};
15+
at::CheckedFrom c{"fps"};
16+
17+
at::checkAllDefined(c, {src_arg, ptr_arg});
18+
at::checkDim(c, ptr_arg, 1);
19+
20+
TORCH_CHECK(ratio > 0.0 && ratio <= 1.0, "ratio must be in the range (0, 1]");
21+
22+
auto src_c = src.view({src.size(0), -1}).contiguous();
23+
auto ptr_c = ptr.contiguous();
24+
25+
static auto op = c10::Dispatcher::singleton()
26+
.findSchemaOrThrow("pyg::fps", "")
27+
.typed<decltype(fps)>();
28+
return op.call(src_c, ptr_c, ratio, random_start);
29+
}
30+
31+
TORCH_LIBRARY_FRAGMENT(pyg, m) {
32+
m.def(TORCH_SELECTIVE_SCHEMA(
33+
"pyg::fps(Tensor src, Tensor ptr, float ratio=0.5, "
34+
"bool random_start=True) -> Tensor"));
35+
}
36+
37+
} // namespace ops
38+
} // namespace pyg

pyg_lib/csrc/ops/fps.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "pyg_lib/csrc/macros.h"
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
PYG_API at::Tensor fps(const at::Tensor& src,
10+
const at::Tensor& ptr,
11+
double ratio,
12+
bool random_start);
13+
14+
} // namespace ops
15+
} // namespace pyg

pyg_lib/ops/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,29 @@ def grid_cluster(
418418
return torch.ops.pyg.grid_cluster(pos, size, start, end)
419419

420420

421+
def fps(
422+
src: Tensor,
423+
ptr: Tensor,
424+
ratio: float = 0.5,
425+
random_start: bool = True,
426+
) -> Tensor:
427+
r"""Performs greedy farthest point sampling.
428+
429+
Starting from a random point (or the first point), iteratively selects
430+
the point that is farthest from the already selected set.
431+
432+
Args:
433+
src: Point positions of shape :obj:`[N, D]`.
434+
ptr: Batch boundaries as a CSR pointer of shape :obj:`[B + 1]`.
435+
ratio: Fraction of points to sample from each batch (in :obj:`(0, 1]`).
436+
random_start: If :obj:`True`, starts from a random point.
437+
438+
Returns:
439+
Indices of the sampled points of shape :obj:`[M]`.
440+
"""
441+
return torch.ops.pyg.fps(src, ptr, ratio, random_start)
442+
443+
421444
__all__ = [
422445
'grouped_matmul',
423446
'segment_matmul',
@@ -430,4 +453,5 @@ def grid_cluster(
430453
'spline_basis',
431454
'spline_weighting',
432455
'grid_cluster',
456+
'fps',
433457
]

test/ops/test_fps.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import pytest
2+
import torch
3+
4+
import pyg_lib
5+
6+
7+
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
8+
def test_fps_output_size(dtype: torch.dtype) -> None:
9+
N, D = 20, 3
10+
src = torch.randn(N, D, dtype=dtype)
11+
ptr = torch.tensor([0, N], dtype=torch.long)
12+
13+
out = pyg_lib.ops.fps(src, ptr, ratio=0.5, random_start=False)
14+
assert out.shape == (10, )
15+
assert out.dtype == torch.long
16+
# All indices should be within range:
17+
assert out.min() >= 0
18+
assert out.max() < N
19+
20+
21+
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
22+
def test_fps_farthest_property(dtype: torch.dtype) -> None:
23+
# After FPS, the minimum pairwise distance between selected points
24+
# should be >= the greedy guarantee.
25+
src = torch.randn(50, 3, dtype=dtype)
26+
ptr = torch.tensor([0, 50], dtype=torch.long)
27+
28+
out = pyg_lib.ops.fps(src, ptr, ratio=0.2, random_start=False)
29+
selected = src[out]
30+
dists = torch.cdist(selected, selected)
31+
dists.fill_diagonal_(float('inf'))
32+
min_dist = dists.min()
33+
assert min_dist > 0
34+
35+
36+
def test_fps_multi_batch() -> None:
37+
src = torch.randn(30, 3)
38+
ptr = torch.tensor([0, 10, 30], dtype=torch.long)
39+
40+
out = pyg_lib.ops.fps(src, ptr, ratio=0.5, random_start=False)
41+
# Batch 0: ceil(10 * 0.5) = 5, Batch 1: ceil(20 * 0.5) = 10
42+
assert out.shape == (15, )
43+
# First 5 indices in batch 0:
44+
assert (out[:5] < 10).all()
45+
assert (out[:5] >= 0).all()
46+
# Next 10 in batch 1:
47+
assert (out[5:] >= 10).all()
48+
assert (out[5:] < 30).all()
49+
50+
51+
def test_fps_random_start() -> None:
52+
src = torch.randn(20, 3)
53+
ptr = torch.tensor([0, 20], dtype=torch.long)
54+
55+
out_det = pyg_lib.ops.fps(src, ptr, ratio=0.5, random_start=False)
56+
# Deterministic: first selected index is always 0
57+
assert out_det[0] == 0
58+
59+
60+
def test_fps_ratio_one() -> None:
61+
# ratio=1.0 should return all points.
62+
N = 15
63+
src = torch.randn(N, 3)
64+
ptr = torch.tensor([0, N], dtype=torch.long)
65+
66+
out = pyg_lib.ops.fps(src, ptr, ratio=1.0, random_start=False)
67+
assert out.shape == (N, )
68+
assert set(out.tolist()) == set(range(N))
69+
70+
71+
def test_fps_single_point_batch() -> None:
72+
# Edge case: batch with a single point.
73+
src = torch.randn(1, 3)
74+
ptr = torch.tensor([0, 1], dtype=torch.long)
75+
76+
out = pyg_lib.ops.fps(src, ptr, ratio=1.0, random_start=False)
77+
assert out.shape == (1, )
78+
assert out[0] == 0
79+
80+
81+
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
82+
def test_fps_greedy_property(dtype: torch.dtype) -> None:
83+
# Verify the greedy FPS invariant: each selected point (after the first)
84+
# must be the farthest from the already-selected set at the time of its
85+
# selection.
86+
src = torch.randn(30, 3, dtype=dtype)
87+
ptr = torch.tensor([0, 30], dtype=torch.long)
88+
89+
out = pyg_lib.ops.fps(src, ptr, ratio=0.5, random_start=False)
90+
91+
selected = [out[0].item()]
92+
for i in range(1, out.shape[0]):
93+
# Minimum distance from each candidate to the selected set so far:
94+
sel = src[selected]
95+
dists = torch.cdist(src.unsqueeze(0), sel.unsqueeze(0)).squeeze(0)
96+
min_dists = dists.min(dim=1).values
97+
# The point FPS picked should have the maximum min-distance:
98+
expected = min_dists.argmax().item()
99+
assert out[i].item() == expected
100+
selected.append(out[i].item())

0 commit comments

Comments
 (0)