Skip to content

Commit 31f480b

Browse files
committed
Add grid_cluster CUDA kernel
Port CUDA grid voxelization kernel. One thread per point computes the flattened voxel index. Tests verify CPU/CUDA parity.
1 parent 2c53091 commit 31f480b

File tree

2 files changed

+104
-16
lines changed

2 files changed

+104
-16
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include "../cluster.h"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
6+
7+
namespace pyg {
8+
namespace ops {
9+
10+
namespace {
11+
12+
#define THREADS 1024
13+
#define BLOCKS(N) ((N) + THREADS - 1) / THREADS
14+
15+
// Computes, for each point, the flattened index of the voxel it falls into.
// One thread per point. `pos` is a dense (numel, D) row-major array; `size`,
// `start` and `end` each hold D entries. Voxel indices are flattened in
// row-major order over the D dimensions, with `end` itself counted as inside
// the grid (hence the `+ 1` on the per-dimension voxel count).
template <typename scalar_t>
__global__ void grid_cluster_cuda_kernel(const scalar_t* __restrict__ pos,
                                         const scalar_t* __restrict__ size,
                                         const scalar_t* __restrict__ start,
                                         const scalar_t* __restrict__ end,
                                         int64_t* __restrict__ out,
                                         int64_t D,
                                         int64_t numel) {
  // Promote to 64-bit *before* multiplying: `blockIdx.x * blockDim.x` is
  // unsigned 32-bit arithmetic and silently wraps once numel exceeds 2^32.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (thread_idx < numel) {
    int64_t c = 0, k = 1;  // c: flattened index, k: running dimension stride.
    for (int64_t d = 0; d < D; d++) {
      const scalar_t p = pos[thread_idx * D + d] - start[d];
      c += static_cast<int64_t>(p / size[d]) * k;
      // Number of voxels along dimension d; `+ 1` keeps `end` inclusive.
      k *= static_cast<int64_t>((end[d] - start[d]) / size[d]) + 1;
    }
    out[thread_idx] = c;
  }
}
35+
36+
// Host wrapper: validates inputs, derives missing `start`/`end` bounds from
// the data, and launches `grid_cluster_cuda_kernel` on the current stream.
// Returns a (N,) int64 tensor of flattened voxel indices, one per point.
at::Tensor grid_cluster_cuda(const at::Tensor& pos,
                             const at::Tensor& size,
                             const std::optional<at::Tensor>& optional_start,
                             const std::optional<at::Tensor>& optional_end) {
  TORCH_CHECK(pos.is_cuda(), "pos must be a CUDA tensor");
  TORCH_CHECK(pos.is_contiguous(), "pos must be contiguous");
  TORCH_CHECK(pos.dim() == 2, "pos must be two-dimensional");
  TORCH_CHECK(size.is_cuda(), "size must be a CUDA tensor");
  TORCH_CHECK(size.numel() == pos.size(1),
              "size must have one entry per dimension of pos");

  // Make sure the launch targets the device holding `pos` (multi-GPU safety).
  c10::cuda::CUDAGuard device_guard(pos.device());

  const auto N = pos.size(0);
  const auto D = pos.size(1);

  auto out = at::empty({N}, pos.options().dtype(at::kLong));
  // Early out: a <<<0, THREADS>>> launch is an invalid configuration, and
  // pos.min(0)/pos.max(0) below would throw on an empty tensor.
  if (N == 0)
    return out;

  // The kernel indexes raw pointers, so every auxiliary tensor must be dense.
  const auto size_c = size.contiguous();

  at::Tensor start;
  if (optional_start.has_value())
    start = optional_start.value().contiguous();
  else
    start = std::get<0>(pos.min(0));  // per-dimension minimum (contiguous)

  at::Tensor end;
  if (optional_end.has_value())
    end = optional_end.value().contiguous();
  else
    end = std::get<0>(pos.max(0));  // per-dimension maximum (contiguous)

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, pos.scalar_type(),
      "grid_cluster_cuda", [&] {
        grid_cluster_cuda_kernel<scalar_t><<<BLOCKS(N), THREADS, 0, stream>>>(
            pos.data_ptr<scalar_t>(), size_c.data_ptr<scalar_t>(),
            start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
            out.data_ptr<int64_t>(), D, N);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });

  return out;
}
73+
74+
} // namespace
75+
76+
// Register the CUDA backend implementation of the `pyg::grid_cluster` op.
TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::grid_cluster"),
         TORCH_FN(grid_cluster_cuda));
}
80+
81+
} // namespace ops
82+
} // namespace pyg

test/ops/test_cluster.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
import torch
33

44
import pyg_lib
5+
from pyg_lib.testing import withCUDA
56

67

8+
@withCUDA
79
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
8-
def test_grid_cluster_2d(dtype: torch.dtype) -> None:
10+
def test_grid_cluster_2d(dtype: torch.dtype, device: torch.device) -> None:
911
pos = torch.tensor(
1012
[[0.0, 0.0], [0.1, 0.1], [0.5, 0.5], [1.0, 1.0], [1.1, 1.1]],
11-
dtype=dtype)
12-
size = torch.tensor([0.5, 0.5], dtype=dtype)
13+
dtype=dtype, device=device)
14+
size = torch.tensor([0.5, 0.5], dtype=dtype, device=device)
1315

1416
out = pyg_lib.ops.grid_cluster(pos, size)
1517

@@ -21,37 +23,41 @@ def test_grid_cluster_2d(dtype: torch.dtype) -> None:
2123
assert out[3] == out[4]
2224

2325

26+
@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_grid_cluster_3d(dtype: torch.dtype, device: torch.device) -> None:
    # Two nearby points must share a voxel; the far-away point must not.
    kwargs = dict(dtype=dtype, device=device)
    pos = torch.tensor([[0.0, 0.0, 0.0], [0.1, 0.1, 0.1], [1.0, 1.0, 1.0]],
                       **kwargs)
    size = torch.tensor([0.5, 0.5, 0.5], **kwargs)

    cluster = pyg_lib.ops.grid_cluster(pos, size)

    assert cluster[0] == cluster[1]
    assert cluster[0] != cluster[2]
3437

3538

39+
@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_grid_cluster_with_start_end(dtype: torch.dtype,
                                     device: torch.device) -> None:
    # Explicit bounds define a 3x3 voxel grid of cell size 0.5 over [0, 1]^2.
    pos = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]], dtype=dtype,
                       device=device)
    size = torch.tensor([0.5, 0.5], dtype=dtype, device=device)
    start = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
    end = torch.tensor([1.0, 1.0], dtype=dtype, device=device)

    out = pyg_lib.ops.grid_cluster(pos, size, start, end)

    assert out.shape == (3, )
    assert out.dtype == torch.long
    # Previously only shape/dtype were checked; also pin down placement and
    # content: the output stays on the input device, and the three points lie
    # in three distinct voxels of the grid.
    assert out.device == pos.device
    assert out.unique().numel() == 3
4753

4854

49-
def test_grid_cluster_defaults_match_explicit() -> None:
55+
@withCUDA
def test_grid_cluster_cpu_cuda_parity(device: torch.device) -> None:
    # The device implementation must assign exactly the same voxel indices
    # as the CPU reference.
    points = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
    cell = torch.tensor([0.5, 0.5])

    expected = pyg_lib.ops.grid_cluster(points, cell)
    actual = pyg_lib.ops.grid_cluster(points.to(device), cell.to(device))

    assert torch.equal(expected, actual.cpu())

0 commit comments

Comments
 (0)