Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions pyg_lib/csrc/ops/cuda/cluster_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include "../cluster.h"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

namespace pyg {
namespace ops {

namespace {

// Number of threads per block used by every kernel launch in this file.
#define THREADS 1024
// Ceiling division: number of blocks needed to cover N elements.
// Fully parenthesized so the macro stays correct when embedded in a larger
// expression (e.g. `BLOCKS(N) * 2` — the original expanded to
// `((N)+THREADS-1)/THREADS * 2`, which happens to be fine, but
// `x % BLOCKS(N)` would not have been).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Assigns each point to a flat grid-cell index.
//
// Launch layout: 1-D grid, one thread per point (row of `pos`); threads with
// thread_idx >= numel exit immediately, so any grid that covers `numel`
// threads is valid.
//
// Preconditions (enforced by the host wrapper):
//   - `pos` is a contiguous (numel, D) array in row-major order.
//   - `size`, `start`, `end` are contiguous arrays of D elements each.
//
// The cell index is built dimension by dimension in mixed-radix form:
// dimension d contributes `floor((pos - start) / size)` scaled by the product
// of the cell counts of all previous dimensions. NOTE(review): the cast
// truncates toward zero, so a point slightly below `start[d]` (possible when
// `start` is user-supplied) lands in cell 0 rather than cell -1 — presumably
// callers guarantee start <= pos; confirm against the operator contract.
template <typename scalar_t>
__global__ void grid_cluster_cuda_kernel(const scalar_t* __restrict__ pos,
                                         const scalar_t* __restrict__ size,
                                         const scalar_t* __restrict__ start,
                                         const scalar_t* __restrict__ end,
                                         int64_t* __restrict__ out,
                                         int64_t D,
                                         int64_t numel) {
  // Widen before multiplying: `blockIdx.x * blockDim.x` alone is a 32-bit
  // unsigned product and silently wraps for very large launches.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (thread_idx < numel) {
    int64_t c = 0, k = 1;  // c: accumulated cell index, k: radix so far.
    for (int64_t d = 0; d < D; d++) {
      const scalar_t p = pos[thread_idx * D + d] - start[d];
      c += static_cast<int64_t>(p / size[d]) * k;
      // Number of cells along dimension d; +1 so a point exactly at `end`
      // still falls into a valid cell.
      k *= static_cast<int64_t>((end[d] - start[d]) / size[d]) + 1;
    }
    out[thread_idx] = c;
  }
}

// CUDA implementation of `pyg::grid_cluster`.
//
// Args:
//   pos: contiguous CUDA tensor of shape (N, D) holding point coordinates.
//   size: tensor with D entries giving the grid-cell extent per dimension.
//   optional_start / optional_end: optional per-dimension lower/upper bounds
//     of the grid; when absent they default to the column-wise min/max of
//     `pos`.
//
// Returns:
//   A (N,) int64 tensor with the flat grid-cell index of every point.
at::Tensor grid_cluster_cuda(const at::Tensor& pos,
                             const at::Tensor& size,
                             const std::optional<at::Tensor>& optional_start,
                             const std::optional<at::Tensor>& optional_end) {
  TORCH_CHECK(pos.is_cuda(), "pos must be a CUDA tensor");
  TORCH_CHECK(pos.is_contiguous(), "pos must be contiguous");
  TORCH_CHECK(pos.dim() == 2, "pos must be two-dimensional");

  const auto N = pos.size(0);
  const auto D = pos.size(1);

  // The kernel reads `size` through a raw pointer with unit stride, so force
  // a dense layout and validate its length up front.
  const auto size_c = size.contiguous();
  TORCH_CHECK(size_c.numel() == D,
              "size must have one entry per dimension of pos");

  auto out = at::empty({N}, pos.options().dtype(at::kLong));
  // Early-out: BLOCKS(0) would be 0, and a zero-block launch is an invalid
  // configuration. (Also avoids calling min()/max() on an empty tensor.)
  if (N == 0)
    return out;

  at::Tensor start;
  if (optional_start.has_value())
    start = optional_start.value().contiguous();
  else
    start = std::get<0>(pos.min(0));  // column-wise minimum, already dense.

  at::Tensor end;
  if (optional_end.has_value())
    end = optional_end.value().contiguous();
  else
    end = std::get<0>(pos.max(0));  // column-wise maximum, already dense.

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, pos.scalar_type(),
      "grid_cluster_cuda", [&] {
        grid_cluster_cuda_kernel<scalar_t><<<BLOCKS(N), THREADS, 0, stream>>>(
            pos.data_ptr<scalar_t>(), size_c.data_ptr<scalar_t>(),
            start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
            out.data_ptr<int64_t>(), D, N);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });

  return out;
}

} // namespace

// Registers `grid_cluster_cuda` as the CUDA-dispatch-key implementation of
// the `pyg::grid_cluster` operator with the PyTorch dispatcher.
TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::grid_cluster"),
         TORCH_FN(grid_cluster_cuda));
}

} // namespace ops
} // namespace pyg
56 changes: 39 additions & 17 deletions test/ops/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import torch

import pyg_lib
from pyg_lib.testing import withCUDA


@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_grid_cluster_2d(dtype: torch.dtype) -> None:
@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.bfloat16])
def test_grid_cluster_2d(dtype: torch.dtype, device: torch.device) -> None:
pos = torch.tensor(
[[0.0, 0.0], [0.1, 0.1], [0.5, 0.5], [1.0, 1.0], [1.1, 1.1]],
dtype=dtype)
size = torch.tensor([0.5, 0.5], dtype=dtype)
dtype=dtype, device=device)
size = torch.tensor([0.5, 0.5], dtype=dtype, device=device)

out = pyg_lib.ops.grid_cluster(pos, size)

Expand All @@ -21,37 +23,57 @@ def test_grid_cluster_2d(dtype: torch.dtype) -> None:
assert out[3] == out[4]


@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_grid_cluster_3d(dtype: torch.dtype) -> None:
def test_grid_cluster_3d(dtype: torch.dtype, device: torch.device) -> None:
pos = torch.tensor([[0.0, 0.0, 0.0], [0.1, 0.1, 0.1], [1.0, 1.0, 1.0]],
dtype=dtype)
size = torch.tensor([0.5, 0.5, 0.5], dtype=dtype)
dtype=dtype, device=device)
size = torch.tensor([0.5, 0.5, 0.5], dtype=dtype, device=device)

out = pyg_lib.ops.grid_cluster(pos, size)

assert out[0] == out[1]
assert out[0] != out[2]


@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_grid_cluster_with_start_end(dtype: torch.dtype) -> None:
pos = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]], dtype=dtype)
size = torch.tensor([0.5, 0.5], dtype=dtype)
start = torch.tensor([0.0, 0.0], dtype=dtype)
end = torch.tensor([1.0, 1.0], dtype=dtype)
def test_grid_cluster_with_start_end(dtype: torch.dtype,
device: torch.device) -> None:
pos = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]], dtype=dtype,
device=device)
size = torch.tensor([0.5, 0.5], dtype=dtype, device=device)
start = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
end = torch.tensor([1.0, 1.0], dtype=dtype, device=device)

out = pyg_lib.ops.grid_cluster(pos, size, start, end)

assert out.shape == (3, )
assert out.dtype == torch.long


def test_grid_cluster_defaults_match_explicit() -> None:
pos = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
size = torch.tensor([0.5, 0.5])
@withCUDA
def test_grid_cluster_defaults_match_explicit(device: torch.device) -> None:
pos = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]], device=device)
size = torch.tensor([0.5, 0.5], device=device)

out_default = pyg_lib.ops.grid_cluster(pos, size)
out_explicit = pyg_lib.ops.grid_cluster(pos, size, start=pos.min(0).values,
end=pos.max(0).values)
out_explicit = pyg_lib.ops.grid_cluster(
pos,
size,
start=pos.min(0).values,
end=pos.max(0).values,
)

assert torch.equal(out_default, out_explicit)


@withCUDA
def test_grid_cluster_cpu_cuda_parity(device: torch.device) -> None:
    # The same input must yield identical cluster assignments on CPU and on
    # `device`.
    points = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
    cell_size = torch.tensor([0.5, 0.5])

    expected = pyg_lib.ops.grid_cluster(points, cell_size)
    observed = pyg_lib.ops.grid_cluster(points.to(device), cell_size.to(device))

    assert torch.equal(expected, observed.cpu())
Loading