|
| 1 | +#include "../nearest.h" |
| 2 | +#include "utils.cuh" |
| 3 | + |
| 4 | +#include <ATen/ATen.h> |
| 5 | +#include <ATen/cuda/CUDAContext.h> |
| 6 | +#include <torch/library.h> |
| 7 | + |
| 8 | +namespace pyg { |
| 9 | +namespace ops { |
| 10 | + |
| 11 | +namespace { |
| 12 | + |
| 13 | +#define NEAREST_THREADS 1024 |
| 14 | + |
| 15 | +template <typename scalar_t> |
| 16 | +__global__ void nearest_cuda_kernel(const scalar_t* __restrict__ x, |
| 17 | + const scalar_t* __restrict__ y, |
| 18 | + const int64_t* __restrict__ ptr_x, |
| 19 | + const int64_t* __restrict__ ptr_y, |
| 20 | + int64_t* __restrict__ out, |
| 21 | + int64_t batch_size, |
| 22 | + int64_t dim) { |
| 23 | + const int64_t thread_idx = threadIdx.x; |
| 24 | + const int64_t n_x = blockIdx.x; |
| 25 | + |
| 26 | + int64_t batch_idx = 0; |
| 27 | + for (int64_t b = 0; b < batch_size; b++) { |
| 28 | + if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) { |
| 29 | + batch_idx = b; |
| 30 | + break; |
| 31 | + } |
| 32 | + } |
| 33 | + |
| 34 | + const int64_t y_start = ptr_y[batch_idx]; |
| 35 | + const int64_t y_end = ptr_y[batch_idx + 1]; |
| 36 | + |
| 37 | + __shared__ scalar_t best_dist[NEAREST_THREADS]; |
| 38 | + __shared__ int64_t best_dist_idx[NEAREST_THREADS]; |
| 39 | + |
| 40 | + scalar_t best = (scalar_t)1e38; |
| 41 | + int64_t best_idx = y_start; |
| 42 | + for (int64_t n_y = y_start + thread_idx; n_y < y_end; |
| 43 | + n_y += NEAREST_THREADS) { |
| 44 | + scalar_t dist = 0; |
| 45 | + for (int64_t d = 0; d < dim; d++) { |
| 46 | + scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d]; |
| 47 | + dist += diff * diff; |
| 48 | + } |
| 49 | + |
| 50 | + if (scalar_lt(dist, best)) { |
| 51 | + best = dist; |
| 52 | + best_idx = n_y; |
| 53 | + } |
| 54 | + } |
| 55 | + |
| 56 | + best_dist[thread_idx] = best; |
| 57 | + best_dist_idx[thread_idx] = best_idx; |
| 58 | + |
| 59 | + for (int64_t u = 0; (1 << u) < NEAREST_THREADS; u++) { |
| 60 | + __syncthreads(); |
| 61 | + if (thread_idx < (NEAREST_THREADS >> (u + 1))) { |
| 62 | + int64_t idx_1 = (thread_idx * 2) << u; |
| 63 | + int64_t idx_2 = (thread_idx * 2 + 1) << u; |
| 64 | + if (scalar_gt(best_dist[idx_1], best_dist[idx_2])) { |
| 65 | + best_dist[idx_1] = best_dist[idx_2]; |
| 66 | + best_dist_idx[idx_1] = best_dist_idx[idx_2]; |
| 67 | + } |
| 68 | + } |
| 69 | + } |
| 70 | + |
| 71 | + __syncthreads(); |
| 72 | + if (thread_idx == 0) { |
| 73 | + out[n_x] = best_dist_idx[0]; |
| 74 | + } |
| 75 | +} |
| 76 | + |
| 77 | +at::Tensor nearest_cuda(const at::Tensor& x, |
| 78 | + const at::Tensor& y, |
| 79 | + const std::optional<at::Tensor>& ptr_x, |
| 80 | + const std::optional<at::Tensor>& ptr_y) { |
| 81 | + TORCH_CHECK(x.is_cuda() && y.is_cuda(), "Inputs must be CUDA tensors"); |
| 82 | + TORCH_CHECK(x.is_contiguous() && y.is_contiguous(), |
| 83 | + "Inputs must be contiguous"); |
| 84 | + |
| 85 | + std::optional<at::Tensor> ptr_x_v = ptr_x; |
| 86 | + std::optional<at::Tensor> ptr_y_v = ptr_y; |
| 87 | + |
| 88 | + if (!ptr_x_v.has_value()) |
| 89 | + ptr_x_v = |
| 90 | + at::arange(0, x.size(0) + 1, x.size(0), x.options().dtype(at::kLong)); |
| 91 | + if (!ptr_y_v.has_value()) |
| 92 | + ptr_y_v = |
| 93 | + at::arange(0, y.size(0) + 1, y.size(0), y.options().dtype(at::kLong)); |
| 94 | + |
| 95 | + auto out = at::empty({x.size(0)}, ptr_x_v.value().options()); |
| 96 | + |
| 97 | + auto stream = at::cuda::getCurrentCUDAStream(); |
| 98 | + AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_cuda", [&] { |
| 99 | + nearest_cuda_kernel<scalar_t><<<x.size(0), NEAREST_THREADS, 0, stream>>>( |
| 100 | + x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(), |
| 101 | + ptr_x_v.value().data_ptr<int64_t>(), |
| 102 | + ptr_y_v.value().data_ptr<int64_t>(), out.data_ptr<int64_t>(), |
| 103 | + ptr_x_v.value().size(0) - 1, x.size(1)); |
| 104 | + C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| 105 | + }); |
| 106 | + |
| 107 | + return out; |
| 108 | +} |
| 109 | + |
| 110 | +} // namespace |
| 111 | + |
| 112 | +TORCH_LIBRARY_IMPL(pyg, CUDA, m) { |
| 113 | + m.impl(TORCH_SELECTIVE_NAME("pyg::nearest"), TORCH_FN(nearest_cuda)); |
| 114 | +} |
| 115 | + |
| 116 | +} // namespace ops |
| 117 | +} // namespace pyg |
0 commit comments