Skip to content

Commit 8b5daf0

Browse files
authored
Add graclus_cluster dispatch, CPU and CUDA kernels (#593)
1 parent c1c17b3 commit 8b5daf0

File tree

6 files changed

+509
-0
lines changed

6 files changed

+509
-0
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#include "../graclus.h"
2+
3+
#include <ATen/ATen.h>
4+
#include <torch/library.h>
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
namespace {
10+
11+
// Greedy graclus matching on CPU: visits nodes in a random order and matches
// each still-unmatched node with one unmatched neighbor — the first one when
// no weights are given, the heaviest-edge one otherwise.  Both endpoints of a
// matched pair store the smaller of the two node ids in `out`; a node with no
// unmatched neighbor stores its own id.
//
// Args:
//   rowptr: CSR row pointer of the graph, shape [num_nodes + 1] (int64).
//   col:    CSR column indices, shape [num_edges] (int64).
//   weight: optional edge weights aligned with `col`.
// Returns: int64 tensor of shape [num_nodes] holding each node's cluster id.
at::Tensor graclus_kernel(const at::Tensor& rowptr,
                          const at::Tensor& col,
                          const std::optional<at::Tensor>& weight) {
  // `data_ptr` below assumes densely packed memory; reject non-contiguous
  // inputs instead of silently reading wrong values.
  TORCH_CHECK(rowptr.is_contiguous(), "'rowptr' must be contiguous");
  TORCH_CHECK(col.is_contiguous(), "'col' must be contiguous");
  TORCH_CHECK(!weight.has_value() || weight.value().is_contiguous(),
              "'weight' must be contiguous");

  int64_t num_nodes = rowptr.numel() - 1;
  auto out = at::full({num_nodes}, -1, rowptr.options());
  // Random visiting order decorrelates the greedy matching from node ids.
  auto node_perm = at::randperm(num_nodes, rowptr.options());

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto node_perm_data = node_perm.data_ptr<int64_t>();
  auto out_data = out.data_ptr<int64_t>();

  if (!weight.has_value()) {
    for (int64_t n = 0; n < num_nodes; n++) {
      auto u = node_perm_data[n];

      if (out_data[u] >= 0)  // Already matched in an earlier step.
        continue;

      // Default to a self-match; overwritten below if a partner is found.
      out_data[u] = u;

      int64_t row_start = rowptr_data[u], row_end = rowptr_data[u + 1];

      for (int64_t e = 0; e < row_end - row_start; e++) {
        auto v = col_data[row_start + e];

        if (out_data[v] >= 0)
          continue;

        // Match u with its first unmatched neighbor.
        out_data[u] = std::min(u, v);
        out_data[v] = std::min(u, v);
        break;
      }
    }
  } else {
    auto scalar_type = weight.value().scalar_type();
    AT_DISPATCH_ALL_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type,
        "graclus_cpu", [&] {
          auto weight_data = weight.value().data_ptr<scalar_t>();

          for (int64_t n = 0; n < num_nodes; n++) {
            auto u = node_perm_data[n];

            if (out_data[u] >= 0)
              continue;

            auto v_max = u;  // Self-match if no unmatched neighbor exists.
            bool found = false;
            scalar_t w_max = (scalar_t)0.;

            for (int64_t e = rowptr_data[u]; e < rowptr_data[u + 1]; e++) {
              auto v = col_data[e];

              if (out_data[v] >= 0)
                continue;

              // Accept the first unmatched neighbor unconditionally so that
              // negative edge weights still yield a partner; afterwards keep
              // the maximum-weight candidate (ties favor later edges).
              if (!found || weight_data[e] >= w_max) {
                found = true;
                v_max = v;
                w_max = weight_data[e];
              }
            }

            out_data[u] = std::min(u, v_max);
            out_data[v_max] = std::min(u, v_max);
          }
        });
  }

  return out;
}
81+
82+
} // namespace
83+
84+
// Registers the CPU implementation for the `pyg::graclus_cluster` operator
// (schema declared via TORCH_LIBRARY_FRAGMENT elsewhere in this extension).
TORCH_LIBRARY_IMPL(pyg, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::graclus_cluster"),
         TORCH_FN(graclus_kernel));
}
88+
89+
} // namespace ops
90+
} // namespace pyg
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
#include "../graclus.h"
#include "utils.cuh"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
7+
8+
namespace pyg {
9+
namespace ops {
10+
11+
namespace {
12+
13+
// Launch configuration: fixed block size with enough blocks to cover N.
#define GRACLUS_THREADS 256
#define GRACLUS_BLOCKS(N) ((N) + GRACLUS_THREADS - 1) / GRACLUS_THREADS
// Probability with which an unmatched node is colored "red" (proposer, -1)
// in each round; the complement becomes "blue" (responder, -2).
// NOTE(review): presumably taken from the randomized-matching literature —
// confirm the intended source of this constant.
#define BLUE_P 0.53406

// Device-global convergence flag: `colorize_kernel` clears it whenever at
// least one node is still unmatched, so the host knows to run another round.
__device__ bool done_d;
18+
19+
// Resets the device-side convergence flag before a coloring round.
// Launched with a single thread (<<<1, 1>>>).
__global__ void init_done_kernel() {
  done_d = true;
}
22+
23+
// Randomly colors every still-unmatched node and flags that another matching
// round is required.  A Bernoulli draw of 1 maps to -1 (red, proposes) and a
// draw of 0 maps to -2 (blue, responds); matched nodes (out >= 0) are left
// untouched.
__global__ void colorize_kernel(int64_t* out,
                                const float* bernoulli,
                                int64_t numel) {
  const int64_t node = blockIdx.x * blockDim.x + threadIdx.x;
  if (node >= numel)
    return;
  if (out[node] >= 0)  // Already matched; nothing to recolor.
    return;
  out[node] = static_cast<int64_t>(bernoulli[node]) - 2;
  done_d = false;  // Benign race: every writer stores the same value.
}
34+
35+
// Runs one coloring round: every unmatched node in `out` is randomly colored
// red (-1) or blue (-2).  Returns true when no unmatched node remained, i.e.
// the matching is complete.
bool colorize(at::Tensor out) {
  auto stream = at::cuda::getCurrentCUDAStream();
  init_done_kernel<<<1, 1, 0, stream>>>();
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  auto numel = out.size(0);
  if (numel > 0) {  // A zero-sized grid would be an invalid launch config.
    auto props = at::full({numel}, BLUE_P, out.options().dtype(at::kFloat));
    auto bernoulli = props.bernoulli();

    colorize_kernel<<<GRACLUS_BLOCKS(numel), GRACLUS_THREADS, 0, stream>>>(
        out.data_ptr<int64_t>(), bernoulli.data_ptr<float>(), numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  bool done_h;
  // Copy the flag back on the same stream the kernels ran on: a synchronous
  // `cudaMemcpyFromSymbol` uses the legacy default stream, which does not
  // synchronize with PyTorch's non-blocking current stream and could read
  // `done_d` before `colorize_kernel` has finished.
  C10_CUDA_CHECK(cudaMemcpyFromSymbolAsync(&done_h, done_d, sizeof(done_h), 0,
                                           cudaMemcpyDeviceToHost, stream));
  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
  return done_h;
}
51+
52+
// Red (-1) nodes scan their neighborhood and propose to the first blue (-2)
// neighbor they find.  A red node with no unmatched neighbor at all matches
// with itself.
__global__ void propose_kernel(int64_t* out,
                               int64_t* proposal,
                               const int64_t* rowptr,
                               const int64_t* col,
                               int64_t numel) {
  const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
  if (u >= numel || out[u] != -1)
    return;

  bool any_unmatched = false;
  const int64_t row_end = rowptr[u + 1];

  for (int64_t e = rowptr[u]; e < row_end; e++) {
    const int64_t v = col[e];

    if (out[v] < 0)
      any_unmatched = true;

    if (out[v] == -2) {  // First blue neighbor receives the proposal.
      proposal[u] = v;
      break;
    }
  }

  if (!any_unmatched)
    out[u] = u;  // No potential partner left: match with itself.
}
80+
81+
// Weighted variant of `propose_kernel`: a red (-1) node proposes to the blue
// (-2) neighbor connected by the largest edge weight (ties resolved in favor
// of later edges).  A red node with no unmatched neighbor matches with
// itself.
template <typename scalar_t>
__global__ void weighted_propose_kernel(int64_t* out,
                                        int64_t* proposal,
                                        const int64_t* rowptr,
                                        const int64_t* col,
                                        const scalar_t* weight,
                                        int64_t numel) {
  const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) {
    if (out[idx] != -1)
      return;

    bool has_unmatched_neighbor = false;
    bool found = false;  // Whether any blue neighbor has been seen yet.
    int64_t v_max = -1;
    scalar_t w_max = 0;

    for (int64_t i = rowptr[idx]; i < rowptr[idx + 1]; i++) {
      auto v = col[i];

      if (out[v] < 0)
        has_unmatched_neighbor = true;

      // Accept the first blue neighbor unconditionally so that graphs with
      // negative edge weights still make progress (otherwise no candidate
      // ever beats the initial w_max of 0 and the host loop never
      // terminates); afterwards keep the maximum-weight candidate.
      if (out[v] == -2 && (!found || scalar_ge(weight[i], w_max))) {
        found = true;
        v_max = v;
        w_max = weight[i];
      }
    }

    proposal[idx] = v_max;

    if (!has_unmatched_neighbor)
      out[idx] = idx;
  }
}
115+
116+
// Launches the (optionally weighted) proposal kernel over all nodes on the
// current CUDA stream.
void propose(at::Tensor out,
             at::Tensor proposal,
             at::Tensor rowptr,
             at::Tensor col,
             const std::optional<at::Tensor>& weight) {
  const auto numel = out.numel();
  if (numel == 0)  // A zero-sized grid is an invalid launch configuration.
    return;

  auto stream = at::cuda::getCurrentCUDAStream();

  if (!weight.has_value()) {
    propose_kernel<<<GRACLUS_BLOCKS(numel), GRACLUS_THREADS, 0, stream>>>(
        out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
        rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    auto w = weight.value();
    // NOTE(review): the CPU kernel dispatches on all types including Half
    // and BFloat16, while this path supports float/double only — confirm
    // whether that asymmetry is intended.
    AT_DISPATCH_FLOATING_TYPES(w.scalar_type(), "graclus_propose_cuda", [&] {
      weighted_propose_kernel<scalar_t>
          <<<GRACLUS_BLOCKS(numel), GRACLUS_THREADS, 0, stream>>>(
              out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
              rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
              w.data_ptr<scalar_t>(), numel);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }
}
140+
141+
// Blue (-2) nodes accept the first red (-1) neighbor that proposed to them;
// both endpoints then store the smaller node id as their cluster.  A blue
// node with no unmatched neighbor matches with itself.
__global__ void respond_kernel(int64_t* out,
                               const int64_t* proposal,
                               const int64_t* rowptr,
                               const int64_t* col,
                               int64_t numel) {
  const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
  if (u >= numel || out[u] != -2)
    return;

  bool any_unmatched = false;
  const int64_t row_end = rowptr[u + 1];

  for (int64_t e = rowptr[u]; e < row_end; e++) {
    const int64_t v = col[e];

    if (out[v] < 0)
      any_unmatched = true;

    if (out[v] == -1 && proposal[v] == u) {  // v proposed to us: accept.
      const int64_t cluster = u < v ? u : v;
      out[u] = cluster;
      out[v] = cluster;
      break;
    }
  }

  if (!any_unmatched)
    out[u] = u;  // No potential partner left: match with itself.
}
171+
172+
// Weighted variant of `respond_kernel`: a blue (-2) node accepts, among the
// red (-1) neighbors that proposed to it, the one connected by the heaviest
// edge.  A blue node with no unmatched neighbor matches with itself.
template <typename scalar_t>
__global__ void weighted_respond_kernel(int64_t* out,
                                        const int64_t* proposal,
                                        const int64_t* rowptr,
                                        const int64_t* col,
                                        const scalar_t* weight,
                                        int64_t numel) {
  const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) {
    if (out[idx] != -2)
      return;

    bool has_unmatched_neighbor = false;
    bool found = false;  // Whether any proposing neighbor has been seen yet.
    int64_t v_max = -1;
    scalar_t w_max = 0;

    for (int64_t i = rowptr[idx]; i < rowptr[idx + 1]; i++) {
      auto v = col[i];

      if (out[v] < 0)
        has_unmatched_neighbor = true;

      // Accept the first proposer unconditionally so that negative edge
      // weights cannot starve the matching (otherwise no candidate ever
      // beats the initial w_max of 0); afterwards keep the maximum.
      if (out[v] == -1 && proposal[v] == idx &&
          (!found || scalar_ge(weight[i], w_max))) {
        found = true;
        v_max = v;
        w_max = weight[i];
      }
    }

    if (v_max >= 0) {
      int64_t m = idx < v_max ? idx : v_max;
      out[idx] = m;
      out[v_max] = m;
    }

    if (!has_unmatched_neighbor)
      out[idx] = idx;
  }
}
210+
211+
// Launches the (optionally weighted) response kernel over all nodes on the
// current CUDA stream.
void respond(at::Tensor out,
             at::Tensor proposal,
             at::Tensor rowptr,
             at::Tensor col,
             const std::optional<at::Tensor>& weight) {
  const auto numel = out.numel();
  if (numel == 0)  // A zero-sized grid is an invalid launch configuration.
    return;

  auto stream = at::cuda::getCurrentCUDAStream();

  if (!weight.has_value()) {
    respond_kernel<<<GRACLUS_BLOCKS(numel), GRACLUS_THREADS, 0, stream>>>(
        out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
        rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    auto w = weight.value();
    // NOTE(review): dtype coverage is narrower than the CPU path — see the
    // matching note in `propose`.
    AT_DISPATCH_FLOATING_TYPES(w.scalar_type(), "graclus_respond_cuda", [&] {
      weighted_respond_kernel<scalar_t>
          <<<GRACLUS_BLOCKS(numel), GRACLUS_THREADS, 0, stream>>>(
              out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
              rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
              w.data_ptr<scalar_t>(), numel);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }
}
235+
236+
// CUDA entry point for `pyg::graclus_cluster`: repeatedly colors the
// remaining unmatched nodes and runs one propose/respond handshake per round
// until `colorize` reports that every node is matched.
//
// Args:
//   rowptr: CSR row pointer, shape [num_nodes + 1] (int64, CUDA).
//   col:    CSR column indices, shape [num_edges] (int64, CUDA).
//   weight: optional edge weights aligned with `col`.
// Returns: int64 CUDA tensor of shape [num_nodes] holding cluster ids.
at::Tensor graclus_cuda(const at::Tensor& rowptr,
                        const at::Tensor& col,
                        const std::optional<at::Tensor>& weight) {
  TORCH_CHECK(rowptr.is_cuda() && col.is_cuda(), "Inputs must be CUDA tensors");
  // Kernels read raw pointers, so densely packed memory is required.
  TORCH_CHECK(rowptr.is_contiguous(), "'rowptr' must be contiguous");
  TORCH_CHECK(col.is_contiguous(), "'col' must be contiguous");
  TORCH_CHECK(!weight.has_value() || weight.value().is_contiguous(),
              "'weight' must be contiguous");

  // Make sure kernels and allocations target the tensors' device rather
  // than whatever device happens to be current.
  const c10::cuda::CUDAGuard guard(rowptr.device());

  int64_t num_nodes = rowptr.numel() - 1;
  auto out = at::full({num_nodes}, -1, rowptr.options());
  auto proposal = at::full({num_nodes}, -1, rowptr.options());

  while (!colorize(out)) {
    propose(out, proposal, rowptr, col, weight);
    respond(out, proposal, rowptr, col, weight);
  }

  return out;
}
252+
253+
} // namespace
254+
255+
// Registers the CUDA implementation for the `pyg::graclus_cluster` operator
// (schema declared via TORCH_LIBRARY_FRAGMENT elsewhere in this extension).
TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::graclus_cluster"), TORCH_FN(graclus_cuda));
}
258+
259+
} // namespace ops
260+
} // namespace pyg

pyg_lib/csrc/ops/graclus.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "graclus.h"
2+
3+
#include <ATen/core/dispatch/Dispatcher.h>
4+
#include <torch/library.h>
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
// Validates the arguments and dispatches `pyg::graclus_cluster` to the
// backend (CPU/CUDA) matching the input tensors.
//
// Args:
//   rowptr: CSR row pointer, 1-dimensional, shape [num_nodes + 1].
//   col:    CSR column indices, 1-dimensional, shape [num_edges].
//   weight: optional 1-dimensional edge weights aligned with `col`.
// Returns: tensor of shape [num_nodes] with the cluster id of every node.
PYG_API at::Tensor graclus_cluster(const at::Tensor& rowptr,
                                   const at::Tensor& col,
                                   const std::optional<at::Tensor>& weight) {
  at::TensorArg rowptr_arg{rowptr, "rowptr", 0};
  at::TensorArg col_arg{col, "col", 1};
  at::CheckedFrom c{"graclus_cluster"};

  at::checkAllDefined(c, {rowptr_arg, col_arg});
  at::checkDim(c, rowptr_arg, 1);
  at::checkDim(c, col_arg, 1);
  // The kernels compute `num_nodes = rowptr.numel() - 1`, so an empty
  // rowptr would silently produce a negative node count downstream.
  TORCH_CHECK(rowptr.numel() >= 1, "rowptr must hold at least one element");

  if (weight.has_value()) {
    TORCH_CHECK(weight.value().dim() == 1, "weight must be 1-dimensional");
    TORCH_CHECK(weight.value().numel() == col.numel(),
                "weight must have the same number of elements as col");
  }

  // Resolve the registered kernel once and reuse it on every call.
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("pyg::graclus_cluster", "")
                       .typed<decltype(graclus_cluster)>();
  return op.call(rowptr, col, weight);
}
31+
32+
// Declares the operator schema; the CPU and CUDA kernels register their
// implementations against it via TORCH_LIBRARY_IMPL.
TORCH_LIBRARY_FRAGMENT(pyg, m) {
  m.def(
      TORCH_SELECTIVE_SCHEMA("pyg::graclus_cluster(Tensor rowptr, Tensor col, "
                             "Tensor? weight=None) -> Tensor"));
}
37+
38+
} // namespace ops
39+
} // namespace pyg

0 commit comments

Comments
 (0)