Skip to content

Commit 4949ae5

Browse files
committed
Add grid_cluster dispatch + CPU kernel
Port grid voxelization from pytorch_cluster into pyg-lib. Each point is assigned a 1D cluster index based on its quantized voxel position: floor((pos - start) / size), then flattened with cumulative voxel counts.
1 parent 9be1bb2 commit 4949ae5

File tree

5 files changed

+200
-0
lines changed

5 files changed

+200
-0
lines changed

pyg_lib/csrc/ops/cluster.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include "cluster.h"
2+
3+
#include <ATen/core/dispatch/Dispatcher.h>
4+
#include <torch/library.h>
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
PYG_API at::Tensor grid_cluster(const at::Tensor& pos,
10+
const at::Tensor& size,
11+
const std::optional<at::Tensor>& start,
12+
const std::optional<at::Tensor>& end) {
13+
at::TensorArg pos_arg{pos, "pos", 0};
14+
at::TensorArg size_arg{size, "size", 1};
15+
at::CheckedFrom c{"grid_cluster"};
16+
17+
at::checkAllDefined(c, {pos_arg, size_arg});
18+
19+
auto pos_2d = pos.view({pos.size(0), -1}).contiguous();
20+
auto size_c = size.contiguous();
21+
22+
TORCH_CHECK(size_c.numel() == pos_2d.size(1),
23+
"size.numel() must equal pos dimension count");
24+
25+
if (start.has_value()) {
26+
TORCH_CHECK(start.value().numel() == pos_2d.size(1),
27+
"start.numel() must equal pos dimension count");
28+
}
29+
if (end.has_value()) {
30+
TORCH_CHECK(end.value().numel() == pos_2d.size(1),
31+
"end.numel() must equal pos dimension count");
32+
}
33+
34+
static auto op = c10::Dispatcher::singleton()
35+
.findSchemaOrThrow("pyg::grid_cluster", "")
36+
.typed<decltype(grid_cluster)>();
37+
return op.call(pos_2d, size_c, start, end);
38+
}
39+
40+
TORCH_LIBRARY_FRAGMENT(pyg, m) {
41+
m.def(TORCH_SELECTIVE_SCHEMA(
42+
"pyg::grid_cluster(Tensor pos, Tensor size, "
43+
"Tensor? start=None, Tensor? end=None) -> Tensor"));
44+
}
45+
46+
} // namespace ops
47+
} // namespace pyg

pyg_lib/csrc/ops/cluster.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "pyg_lib/csrc/macros.h"
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
PYG_API at::Tensor grid_cluster(const at::Tensor& pos,
10+
const at::Tensor& size,
11+
const std::optional<at::Tensor>& start,
12+
const std::optional<at::Tensor>& end);
13+
14+
} // namespace ops
15+
} // namespace pyg
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#include "../cluster.h"
2+
3+
#include <ATen/ATen.h>
4+
#include <torch/library.h>
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
namespace {
10+
11+
at::Tensor grid_cluster_kernel(const at::Tensor& pos,
12+
const at::Tensor& size,
13+
const std::optional<at::Tensor>& optional_start,
14+
const std::optional<at::Tensor>& optional_end) {
15+
auto N = pos.size(0);
16+
auto D = pos.size(1);
17+
18+
at::Tensor start;
19+
if (optional_start.has_value())
20+
start = optional_start.value().contiguous();
21+
else
22+
start = std::get<0>(pos.min(0));
23+
24+
at::Tensor end;
25+
if (optional_end.has_value())
26+
end = optional_end.value().contiguous();
27+
else
28+
end = std::get<0>(pos.max(0));
29+
30+
auto pos_shifted = pos - start.unsqueeze(0);
31+
32+
auto num_voxels =
33+
(end - start).div(size, /*rounding_mode=*/"trunc").to(at::kLong) + 1;
34+
num_voxels = num_voxels.cumprod(0);
35+
num_voxels = at::cat({at::ones({1}, num_voxels.options()), num_voxels}, 0);
36+
num_voxels = num_voxels.narrow(0, 0, D);
37+
38+
auto out = pos_shifted.div(size.view({1, -1}), /*rounding_mode=*/"trunc")
39+
.to(at::kLong);
40+
out *= num_voxels.view({1, -1});
41+
out = out.sum(1);
42+
43+
return out;
44+
}
45+
46+
} // namespace
47+
48+
TORCH_LIBRARY_IMPL(pyg, CPU, m) {
49+
m.impl(TORCH_SELECTIVE_NAME("pyg::grid_cluster"),
50+
TORCH_FN(grid_cluster_kernel));
51+
}
52+
53+
} // namespace ops
54+
} // namespace pyg

pyg_lib/ops/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,32 @@ def spline_weighting(
392392
return torch.ops.pyg.spline_weighting(x, weight, basis, weight_index)
393393

394394

395+
def grid_cluster(
396+
pos: Tensor,
397+
size: Tensor,
398+
start: Optional[Tensor] = None,
399+
end: Optional[Tensor] = None,
400+
) -> Tensor:
401+
r"""Clusters all points in :obj:`pos` into voxels of size :obj:`size`.
402+
403+
Each point is assigned a cluster index based on which voxel it falls into.
404+
The voxel grid is defined by the :obj:`size` parameter and optionally
405+
bounded by :obj:`start` and :obj:`end`.
406+
407+
Args:
408+
pos: Point positions of shape :obj:`[N, D]`.
409+
size: Voxel size in each dimension of shape :obj:`[D]`.
410+
start: Start of the voxel grid in each dimension of shape :obj:`[D]`.
411+
If :obj:`None`, uses the minimum of :obj:`pos`.
412+
end: End of the voxel grid in each dimension of shape :obj:`[D]`.
413+
If :obj:`None`, uses the maximum of :obj:`pos`.
414+
415+
Returns:
416+
Cluster index for each point of shape :obj:`[N]`.
417+
"""
418+
return torch.ops.pyg.grid_cluster(pos, size, start, end)
419+
420+
395421
__all__ = [
396422
'grouped_matmul',
397423
'segment_matmul',
@@ -403,4 +429,5 @@ def spline_weighting(
403429
'softmax_csr',
404430
'spline_basis',
405431
'spline_weighting',
432+
'grid_cluster',
406433
]

test/ops/test_cluster.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pytest
2+
import torch
3+
4+
import pyg_lib
5+
6+
7+
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
8+
def test_grid_cluster_2d(dtype: torch.dtype) -> None:
9+
pos = torch.tensor(
10+
[[0.0, 0.0], [0.1, 0.1], [0.5, 0.5], [1.0, 1.0], [1.1, 1.1]],
11+
dtype=dtype)
12+
size = torch.tensor([0.5, 0.5], dtype=dtype)
13+
14+
out = pyg_lib.ops.grid_cluster(pos, size)
15+
16+
# Points (0,0) and (0.1,0.1) should be in the same cluster
17+
assert out[0] == out[1]
18+
# Point (0.5,0.5) should be in a different cluster from (0,0)
19+
assert out[0] != out[2]
20+
# Points (1.0,1.0) and (1.1,1.1) should be in the same cluster
21+
assert out[3] == out[4]
22+
23+
24+
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
25+
def test_grid_cluster_3d(dtype: torch.dtype) -> None:
26+
pos = torch.tensor([[0.0, 0.0, 0.0], [0.1, 0.1, 0.1], [1.0, 1.0, 1.0]],
27+
dtype=dtype)
28+
size = torch.tensor([0.5, 0.5, 0.5], dtype=dtype)
29+
30+
out = pyg_lib.ops.grid_cluster(pos, size)
31+
32+
assert out[0] == out[1]
33+
assert out[0] != out[2]
34+
35+
36+
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
37+
def test_grid_cluster_with_start_end(dtype: torch.dtype) -> None:
38+
pos = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]], dtype=dtype)
39+
size = torch.tensor([0.5, 0.5], dtype=dtype)
40+
start = torch.tensor([0.0, 0.0], dtype=dtype)
41+
end = torch.tensor([1.0, 1.0], dtype=dtype)
42+
43+
out = pyg_lib.ops.grid_cluster(pos, size, start, end)
44+
45+
assert out.shape == (3, )
46+
assert out.dtype == torch.long
47+
48+
49+
def test_grid_cluster_defaults_match_explicit() -> None:
50+
pos = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
51+
size = torch.tensor([0.5, 0.5])
52+
53+
out_default = pyg_lib.ops.grid_cluster(pos, size)
54+
out_explicit = pyg_lib.ops.grid_cluster(pos, size, start=pos.min(0).values,
55+
end=pos.max(0).values)
56+
57+
assert torch.equal(out_default, out_explicit)

0 commit comments

Comments
 (0)