|
| 1 | +#include "../graclus.h" |
| 2 | +#include "utils.cuh" |
| 3 | + |
| 4 | +#include <ATen/ATen.h> |
| 5 | +#include <ATen/cuda/CUDAContext.h> |
| 6 | +#include <torch/library.h> |
| 7 | + |
| 8 | +namespace pyg { |
| 9 | +namespace ops { |
| 10 | + |
| 11 | +namespace { |
| 12 | + |
| 13 | +#define GRACLUS_THREADS 256 |
| 14 | +#define GRACLUS_BLOCKS(N) ((N) + GRACLUS_THREADS - 1) / GRACLUS_THREADS |
| 15 | +#define BLUE_P 0.53406 |
| 16 | + |
| 17 | +__device__ bool done_d; |
| 18 | + |
| 19 | +__global__ void init_done_kernel() { |
| 20 | + done_d = true; |
| 21 | +} |
| 22 | + |
| 23 | +__global__ void colorize_kernel(int64_t* out, |
| 24 | + const float* bernoulli, |
| 25 | + int64_t numel) { |
| 26 | + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 27 | + if (idx < numel) { |
| 28 | + if (out[idx] < 0) { |
| 29 | + out[idx] = (int64_t)bernoulli[idx] - 2; |
| 30 | + done_d = false; |
| 31 | + } |
| 32 | + } |
| 33 | +} |
| 34 | + |
| 35 | +bool colorize(at::Tensor out) { |
| 36 | + auto stream = at::cuda::getCurrentCUDAStream(); |
| 37 | + init_done_kernel<<<1, 1, 0, stream>>>(); |
| 38 | + |
| 39 | + auto numel = out.size(0); |
| 40 | + auto props = at::full({numel}, BLUE_P, out.options().dtype(at::kFloat)); |
| 41 | + auto bernoulli = props.bernoulli(); |
| 42 | + |
| 43 | + colorize_kernel<<<GRACLUS_BLOCKS(numel), GRACLUS_THREADS, 0, stream>>>( |
| 44 | + out.data_ptr<int64_t>(), bernoulli.data_ptr<float>(), numel); |
| 45 | + |
| 46 | + bool done_h; |
| 47 | + cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0, |
| 48 | + cudaMemcpyDeviceToHost); |
| 49 | + return done_h; |
| 50 | +} |
| 51 | + |
| 52 | +__global__ void propose_kernel(int64_t* out, |
| 53 | + int64_t* proposal, |
| 54 | + const int64_t* rowptr, |
| 55 | + const int64_t* col, |
| 56 | + int64_t numel) { |
| 57 | + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 58 | + if (idx < numel) { |
| 59 | + if (out[idx] != -1) |
| 60 | + return; |
| 61 | + |
| 62 | + bool has_unmatched_neighbor = false; |
| 63 | + |
| 64 | + for (int64_t i = rowptr[idx]; i < rowptr[idx + 1]; i++) { |
| 65 | + auto v = col[i]; |
| 66 | + |
| 67 | + if (out[v] < 0) |
| 68 | + has_unmatched_neighbor = true; |
| 69 | + |
| 70 | + if (out[v] == -2) { |
| 71 | + proposal[idx] = v; |
| 72 | + break; |
| 73 | + } |
| 74 | + } |
| 75 | + |
| 76 | + if (!has_unmatched_neighbor) |
| 77 | + out[idx] = idx; |
| 78 | + } |
| 79 | +} |
| 80 | + |
| 81 | +template <typename scalar_t> |
| 82 | +__global__ void weighted_propose_kernel(int64_t* out, |
| 83 | + int64_t* proposal, |
| 84 | + const int64_t* rowptr, |
| 85 | + const int64_t* col, |
| 86 | + const scalar_t* weight, |
| 87 | + int64_t numel) { |
| 88 | + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 89 | + if (idx < numel) { |
| 90 | + if (out[idx] != -1) |
| 91 | + return; |
| 92 | + |
| 93 | + bool has_unmatched_neighbor = false; |
| 94 | + int64_t v_max = -1; |
| 95 | + scalar_t w_max = 0; |
| 96 | + |
| 97 | + for (int64_t i = rowptr[idx]; i < rowptr[idx + 1]; i++) { |
| 98 | + auto v = col[i]; |
| 99 | + |
| 100 | + if (out[v] < 0) |
| 101 | + has_unmatched_neighbor = true; |
| 102 | + |
| 103 | + if (out[v] == -2 && scalar_ge(weight[i], w_max)) { |
| 104 | + v_max = v; |
| 105 | + w_max = weight[i]; |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + proposal[idx] = v_max; |
| 110 | + |
| 111 | + if (!has_unmatched_neighbor) |
| 112 | + out[idx] = idx; |
| 113 | + } |
| 114 | +} |
| 115 | + |
| 116 | +void propose(at::Tensor out, |
| 117 | + at::Tensor proposal, |
| 118 | + at::Tensor rowptr, |
| 119 | + at::Tensor col, |
| 120 | + const std::optional<at::Tensor>& weight) { |
| 121 | + auto stream = at::cuda::getCurrentCUDAStream(); |
| 122 | + |
| 123 | + if (!weight.has_value()) { |
| 124 | + propose_kernel<<<GRACLUS_BLOCKS(out.numel()), GRACLUS_THREADS, 0, stream>>>( |
| 125 | + out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(), |
| 126 | + rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel()); |
| 127 | + C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| 128 | + } else { |
| 129 | + auto w = weight.value(); |
| 130 | + AT_DISPATCH_FLOATING_TYPES(w.scalar_type(), "_", [&] { |
| 131 | + weighted_propose_kernel<scalar_t> |
| 132 | + <<<GRACLUS_BLOCKS(out.numel()), GRACLUS_THREADS, 0, stream>>>( |
| 133 | + out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(), |
| 134 | + rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), |
| 135 | + w.data_ptr<scalar_t>(), out.numel()); |
| 136 | + C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| 137 | + }); |
| 138 | + } |
| 139 | +} |
| 140 | + |
| 141 | +__global__ void respond_kernel(int64_t* out, |
| 142 | + const int64_t* proposal, |
| 143 | + const int64_t* rowptr, |
| 144 | + const int64_t* col, |
| 145 | + int64_t numel) { |
| 146 | + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 147 | + if (idx < numel) { |
| 148 | + if (out[idx] != -2) |
| 149 | + return; |
| 150 | + |
| 151 | + bool has_unmatched_neighbor = false; |
| 152 | + |
| 153 | + for (int64_t i = rowptr[idx]; i < rowptr[idx + 1]; i++) { |
| 154 | + auto v = col[i]; |
| 155 | + |
| 156 | + if (out[v] < 0) |
| 157 | + has_unmatched_neighbor = true; |
| 158 | + |
| 159 | + if (out[v] == -1 && proposal[v] == idx) { |
| 160 | + int64_t m = idx < v ? idx : v; |
| 161 | + out[idx] = m; |
| 162 | + out[v] = m; |
| 163 | + break; |
| 164 | + } |
| 165 | + } |
| 166 | + |
| 167 | + if (!has_unmatched_neighbor) |
| 168 | + out[idx] = idx; |
| 169 | + } |
| 170 | +} |
| 171 | + |
| 172 | +template <typename scalar_t> |
| 173 | +__global__ void weighted_respond_kernel(int64_t* out, |
| 174 | + const int64_t* proposal, |
| 175 | + const int64_t* rowptr, |
| 176 | + const int64_t* col, |
| 177 | + const scalar_t* weight, |
| 178 | + int64_t numel) { |
| 179 | + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 180 | + if (idx < numel) { |
| 181 | + if (out[idx] != -2) |
| 182 | + return; |
| 183 | + |
| 184 | + bool has_unmatched_neighbor = false; |
| 185 | + int64_t v_max = -1; |
| 186 | + scalar_t w_max = 0; |
| 187 | + |
| 188 | + for (int64_t i = rowptr[idx]; i < rowptr[idx + 1]; i++) { |
| 189 | + auto v = col[i]; |
| 190 | + |
| 191 | + if (out[v] < 0) |
| 192 | + has_unmatched_neighbor = true; |
| 193 | + |
| 194 | + if (out[v] == -1 && proposal[v] == idx && scalar_ge(weight[i], w_max)) { |
| 195 | + v_max = v; |
| 196 | + w_max = weight[i]; |
| 197 | + } |
| 198 | + } |
| 199 | + |
| 200 | + if (v_max >= 0) { |
| 201 | + int64_t m = idx < v_max ? idx : v_max; |
| 202 | + out[idx] = m; |
| 203 | + out[v_max] = m; |
| 204 | + } |
| 205 | + |
| 206 | + if (!has_unmatched_neighbor) |
| 207 | + out[idx] = idx; |
| 208 | + } |
| 209 | +} |
| 210 | + |
| 211 | +void respond(at::Tensor out, |
| 212 | + at::Tensor proposal, |
| 213 | + at::Tensor rowptr, |
| 214 | + at::Tensor col, |
| 215 | + const std::optional<at::Tensor>& weight) { |
| 216 | + auto stream = at::cuda::getCurrentCUDAStream(); |
| 217 | + |
| 218 | + if (!weight.has_value()) { |
| 219 | + respond_kernel<<<GRACLUS_BLOCKS(out.numel()), GRACLUS_THREADS, 0, stream>>>( |
| 220 | + out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(), |
| 221 | + rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel()); |
| 222 | + C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| 223 | + } else { |
| 224 | + auto w = weight.value(); |
| 225 | + AT_DISPATCH_FLOATING_TYPES(w.scalar_type(), "_", [&] { |
| 226 | + weighted_respond_kernel<scalar_t> |
| 227 | + <<<GRACLUS_BLOCKS(out.numel()), GRACLUS_THREADS, 0, stream>>>( |
| 228 | + out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(), |
| 229 | + rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), |
| 230 | + w.data_ptr<scalar_t>(), out.numel()); |
| 231 | + C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| 232 | + }); |
| 233 | + } |
| 234 | +} |
| 235 | + |
| 236 | +at::Tensor graclus_cuda(const at::Tensor& rowptr, |
| 237 | + const at::Tensor& col, |
| 238 | + const std::optional<at::Tensor>& weight) { |
| 239 | + TORCH_CHECK(rowptr.is_cuda() && col.is_cuda(), "Inputs must be CUDA tensors"); |
| 240 | + |
| 241 | + int64_t num_nodes = rowptr.numel() - 1; |
| 242 | + auto out = at::full({num_nodes}, -1, rowptr.options()); |
| 243 | + auto proposal = at::full({num_nodes}, -1, rowptr.options()); |
| 244 | + |
| 245 | + while (!colorize(out)) { |
| 246 | + propose(out, proposal, rowptr, col, weight); |
| 247 | + respond(out, proposal, rowptr, col, weight); |
| 248 | + } |
| 249 | + |
| 250 | + return out; |
| 251 | +} |
| 252 | + |
| 253 | +} // namespace |
| 254 | + |
| 255 | +TORCH_LIBRARY_IMPL(pyg, CUDA, m) { |
| 256 | + m.impl(TORCH_SELECTIVE_NAME("pyg::graclus_cluster"), TORCH_FN(graclus_cuda)); |
| 257 | +} |
| 258 | + |
| 259 | +} // namespace ops |
| 260 | +} // namespace pyg |
0 commit comments