|
| 1 | +#include "../fps.h" |
| 2 | + |
| 3 | +#include <ATen/ATen.h> |
| 4 | +#include <ATen/cuda/CUDAContext.h> |
| 5 | +#include <torch/library.h> |
| 6 | + |
| 7 | +namespace pyg { |
| 8 | +namespace ops { |
| 9 | + |
| 10 | +namespace { |
| 11 | + |
| 12 | +#define FPS_THREADS 256 |
| 13 | + |
| 14 | +// Explicit non-template comparison/min functions to avoid NVCC ambiguous |
| 15 | +// operator overload errors from c10::SymInt (error #3343). |
| 16 | +__device__ __forceinline__ bool scalar_gt(float a, float b) { |
| 17 | + return a > b; |
| 18 | +} |
| 19 | +__device__ __forceinline__ bool scalar_gt(double a, double b) { |
| 20 | + return a > b; |
| 21 | +} |
| 22 | +__device__ __forceinline__ bool scalar_lt(float a, float b) { |
| 23 | + return a < b; |
| 24 | +} |
| 25 | +__device__ __forceinline__ bool scalar_lt(double a, double b) { |
| 26 | + return a < b; |
| 27 | +} |
| 28 | +__device__ __forceinline__ float scalar_min(float a, float b) { |
| 29 | + return fminf(a, b); |
| 30 | +} |
| 31 | +__device__ __forceinline__ double scalar_min(double a, double b) { |
| 32 | + return fmin(a, b); |
| 33 | +} |
| 34 | + |
| 35 | +template <typename scalar_t> |
| 36 | +__global__ void fps_cuda_kernel(const scalar_t* src, |
| 37 | + const int64_t* ptr, |
| 38 | + const int64_t* out_ptr, |
| 39 | + const int64_t* start, |
| 40 | + scalar_t* dist, |
| 41 | + int64_t* out, |
| 42 | + int64_t dim) { |
| 43 | + const int64_t thread_idx = threadIdx.x; |
| 44 | + const int64_t batch_idx = blockIdx.x; |
| 45 | + |
| 46 | + const int64_t start_idx = ptr[batch_idx]; |
| 47 | + const int64_t end_idx = ptr[batch_idx + 1]; |
| 48 | + |
| 49 | + __shared__ scalar_t best_dist[FPS_THREADS]; |
| 50 | + __shared__ int64_t best_dist_idx[FPS_THREADS]; |
| 51 | + |
| 52 | + if (thread_idx == 0) { |
| 53 | + out[out_ptr[batch_idx]] = start_idx + start[batch_idx]; |
| 54 | + } |
| 55 | + |
| 56 | + for (int64_t m = out_ptr[batch_idx] + 1; m < out_ptr[batch_idx + 1]; m++) { |
| 57 | + __syncthreads(); |
| 58 | + int64_t old = out[m - 1]; |
| 59 | + |
| 60 | + scalar_t best = (scalar_t)-1.; |
| 61 | + int64_t best_idx = 0; |
| 62 | + |
| 63 | + for (int64_t n = start_idx + thread_idx; n < end_idx; n += FPS_THREADS) { |
| 64 | + scalar_t tmp, dd = (scalar_t)0.; |
| 65 | + for (int64_t d = 0; d < dim; d++) { |
| 66 | + tmp = src[dim * old + d] - src[dim * n + d]; |
| 67 | + dd += tmp * tmp; |
| 68 | + } |
| 69 | + dd = scalar_min(dist[n], dd); |
| 70 | + dist[n] = dd; |
| 71 | + if (scalar_gt(dd, best)) { |
| 72 | + best = dd; |
| 73 | + best_idx = n; |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + best_dist[thread_idx] = best; |
| 78 | + best_dist_idx[thread_idx] = best_idx; |
| 79 | + |
| 80 | + for (int64_t i = 1; i < FPS_THREADS; i *= 2) { |
| 81 | + __syncthreads(); |
| 82 | + if ((thread_idx + i) < FPS_THREADS && |
| 83 | + scalar_lt(best_dist[thread_idx], best_dist[thread_idx + i])) { |
| 84 | + best_dist[thread_idx] = best_dist[thread_idx + i]; |
| 85 | + best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i]; |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + __syncthreads(); |
| 90 | + if (thread_idx == 0) { |
| 91 | + out[m] = best_dist_idx[0]; |
| 92 | + } |
| 93 | + } |
| 94 | +} |
| 95 | + |
| 96 | +at::Tensor fps_cuda(const at::Tensor& src, |
| 97 | + const at::Tensor& ptr, |
| 98 | + double ratio, |
| 99 | + bool random_start) { |
| 100 | + TORCH_CHECK(src.is_cuda(), "src must be a CUDA tensor"); |
| 101 | + TORCH_CHECK(src.is_contiguous(), "src must be contiguous"); |
| 102 | + TORCH_CHECK(ptr.is_cuda(), "ptr must be a CUDA tensor"); |
| 103 | + |
| 104 | + int64_t batch_size = ptr.numel() - 1; |
| 105 | + int64_t D = src.size(1); |
| 106 | + |
| 107 | + auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size); |
| 108 | + auto out_ptr = deg.to(at::kFloat) * ratio; |
| 109 | + out_ptr = out_ptr.ceil().to(at::kLong).cumsum(0); |
| 110 | + out_ptr = at::cat({at::zeros({1}, ptr.options()), out_ptr}, 0); |
| 111 | + |
| 112 | + at::Tensor start; |
| 113 | + if (random_start) { |
| 114 | + start = at::rand({batch_size}, src.options()); |
| 115 | + start = (start * deg.to(at::kFloat)).to(at::kLong); |
| 116 | + } else { |
| 117 | + start = at::zeros({batch_size}, ptr.options()); |
| 118 | + } |
| 119 | + |
| 120 | + auto dist = at::full({src.size(0)}, 5e4, src.options()); |
| 121 | + |
| 122 | + int64_t out_total; |
| 123 | + cudaMemcpy(&out_total, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t), |
| 124 | + cudaMemcpyDeviceToHost); |
| 125 | + auto out = at::empty({out_total}, out_ptr.options()); |
| 126 | + |
| 127 | + auto stream = at::cuda::getCurrentCUDAStream(); |
| 128 | + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_cuda", [&] { |
| 129 | + fps_cuda_kernel<scalar_t><<<batch_size, FPS_THREADS, 0, stream>>>( |
| 130 | + src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(), |
| 131 | + out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(), |
| 132 | + dist.data_ptr<scalar_t>(), out.data_ptr<int64_t>(), D); |
| 133 | + C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| 134 | + }); |
| 135 | + |
| 136 | + return out; |
| 137 | +} |
| 138 | + |
| 139 | +} // namespace |
| 140 | + |
| 141 | +TORCH_LIBRARY_IMPL(pyg, CUDA, m) { |
| 142 | + m.impl(TORCH_SELECTIVE_NAME("pyg::fps"), TORCH_FN(fps_cuda)); |
| 143 | +} |
| 144 | + |
| 145 | +} // namespace ops |
| 146 | +} // namespace pyg |
0 commit comments