Skip to content

Commit 7a1cf27

Browse files
authored
[geometric] Optimize graph sample speed (#47531) (#47548)
1 parent 61953b9 commit 7a1cf27

File tree

1 file changed: 36 additions and 39 deletions

paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu

Lines changed: 36 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ struct MaxFunctor {
5858
}
5959
};
6060

61-
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
61+
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
6262
__global__ void SampleKernel(const uint64_t rand_seed,
6363
int k,
6464
const int64_t num_nodes,
@@ -71,52 +71,51 @@ __global__ void SampleKernel(const uint64_t rand_seed,
7171
T* output_eids,
7272
int* output_ptr,
7373
bool return_eids) {
74-
assert(blockDim.x == WARP_SIZE);
75-
assert(blockDim.y == BLOCK_WARPS);
74+
assert(blockDim.x == CTA_SIZE);
7675

7776
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
7877
const int64_t last_row =
7978
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_nodes);
8079
#ifdef PADDLE_WITH_HIP
8180
hiprandState rng;
8281
hiprand_init(rand_seed * gridDim.x + blockIdx.x,
83-
threadIdx.y * WARP_SIZE + threadIdx.x,
82+
threadIdx.y * CTA_SIZE + threadIdx.x,
8483
0,
8584
&rng);
8685
#else
87-
curandState rng;
86+
curandStatePhilox4_32_10_t rng;
8887
curand_init(rand_seed * gridDim.x + blockIdx.x,
89-
threadIdx.y * WARP_SIZE + threadIdx.x,
88+
threadIdx.y * CTA_SIZE + threadIdx.x,
9089
0,
9190
&rng);
9291
#endif
9392

9493
while (out_row < last_row) {
9594
T node = nodes[out_row];
9695
if (node > len_col_ptr - 1) {
97-
out_row += BLOCK_WARPS;
96+
out_row += BLOCK_CTAS;
9897
continue;
9998
}
10099
T in_row_start = col_ptr[node];
101100
int deg = col_ptr[node + 1] - in_row_start;
102101
int out_row_start = output_ptr[out_row];
103102

104103
if (deg <= k) {
105-
for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
104+
for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
106105
output[out_row_start + idx] = row[in_row_start + idx];
107106
if (return_eids) {
108107
output_eids[out_row_start + idx] = eids[in_row_start + idx];
109108
}
110109
}
111110
} else {
112-
for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
111+
for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) {
113112
output[out_row_start + idx] = idx;
114113
}
115114
#ifdef PADDLE_WITH_CUDA
116-
__syncwarp();
115+
__syncthreads();
117116
#endif
118117

119-
for (int idx = k + threadIdx.x; idx < deg; idx += WARP_SIZE) {
118+
for (int idx = k + threadIdx.x; idx < deg; idx += CTA_SIZE) {
120119
#ifdef PADDLE_WITH_HIP
121120
const int num = hiprand(&rng) % (idx + 1);
122121
#else
@@ -129,10 +128,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
129128
}
130129
}
131130
#ifdef PADDLE_WITH_CUDA
132-
__syncwarp();
131+
__syncthreads();
133132
#endif
134133

135-
for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
134+
for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) {
136135
T perm_idx = output[out_row_start + idx] + in_row_start;
137136
output[out_row_start + idx] = row[perm_idx];
138137
if (return_eids) {
@@ -141,7 +140,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
141140
}
142141
}
143142

144-
out_row += BLOCK_WARPS;
143+
out_row += BLOCK_CTAS;
145144
}
146145
}
147146

@@ -181,12 +180,12 @@ void SampleNeighbors(const Context& dev_ctx,
181180
thrust::exclusive_scan(
182181
output_count, output_count + bs, output_ptr.begin(), 0);
183182

184-
constexpr int WARP_SIZE = 32;
185-
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
186-
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
187-
const dim3 block(WARP_SIZE, BLOCK_WARPS);
183+
constexpr int CTA_SIZE = 128;
184+
constexpr int BLOCK_CTAS = 128 / CTA_SIZE;
185+
constexpr int TILE_SIZE = BLOCK_CTAS;
186+
const dim3 block(CTA_SIZE, BLOCK_CTAS);
188187
const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
189-
SampleKernel<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
188+
SampleKernel<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
190189
<<<grid, block, 0, dev_ctx.stream()>>>(
191190
0,
192191
sample_size,
@@ -202,16 +201,15 @@ void SampleNeighbors(const Context& dev_ctx,
202201
return_eids);
203202
}
204203

205-
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
204+
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
206205
__global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
207206
int k,
208207
const int64_t num_rows,
209208
const int64_t len_col_ptr,
210209
const T* in_rows,
211210
T* src,
212211
const T* dst_count) {
213-
assert(blockDim.x == WARP_SIZE);
214-
assert(blockDim.y == BLOCK_WARPS);
212+
assert(blockDim.x == CTA_SIZE);
215213

216214
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
217215
const int64_t last_row =
@@ -221,15 +219,15 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
221219
hiprand_init(
222220
rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
223221
#else
224-
curandState rng;
222+
curandStatePhilox4_32_10_t rng;
225223
curand_init(
226224
rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
227225
#endif
228226

229227
while (out_row < last_row) {
230228
const T row = in_rows[out_row];
231229
if (row > len_col_ptr - 1) {
232-
out_row += BLOCK_WARPS;
230+
out_row += BLOCK_CTAS;
233231
continue;
234232
}
235233
const T in_row_start = dst_count[row];
@@ -241,7 +239,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
241239
} else {
242240
split = deg - k;
243241
}
244-
for (int idx = split + threadIdx.x; idx <= deg - 1; idx += WARP_SIZE) {
242+
for (int idx = split + threadIdx.x; idx <= deg - 1; idx += CTA_SIZE) {
245243
#ifdef PADDLE_WITH_HIP
246244
const int num = hiprand(&rng) % (idx + 1);
247245
#else
@@ -254,14 +252,14 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
254252
src[in_row_start + idx])));
255253
}
256254
#ifdef PADDLE_WITH_CUDA
257-
__syncwarp();
255+
__syncthreads();
258256
#endif
259257
}
260-
out_row += BLOCK_WARPS;
258+
out_row += BLOCK_CTAS;
261259
}
262260
}
263261

264-
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
262+
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
265263
__global__ void GatherEdge(int k,
266264
int64_t num_rows,
267265
const T* in_rows,
@@ -273,8 +271,7 @@ __global__ void GatherEdge(int k,
273271
int* output_ptr,
274272
T* perm_data,
275273
bool return_eids) {
276-
assert(blockDim.x == WARP_SIZE);
277-
assert(blockDim.y == BLOCK_WARPS);
274+
assert(blockDim.x == CTA_SIZE);
278275

279276
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
280277
const int64_t last_row =
@@ -287,7 +284,7 @@ __global__ void GatherEdge(int k,
287284
const T out_row_start = output_ptr[out_row];
288285

289286
if (deg <= k) {
290-
for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
287+
for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
291288
outputs[out_row_start + idx] = src[in_row_start + idx];
292289
if (return_eids) {
293290
output_eids[out_row_start + idx] = eids[in_row_start + idx];
@@ -304,7 +301,7 @@ __global__ void GatherEdge(int k,
304301
end = deg;
305302
}
306303

307-
for (int idx = begin + threadIdx.x; idx < end; idx += WARP_SIZE) {
304+
for (int idx = begin + threadIdx.x; idx < end; idx += CTA_SIZE) {
308305
outputs[out_row_start + idx - begin] =
309306
src[perm_data[in_row_start + idx]];
310307
if (return_eids) {
@@ -313,7 +310,7 @@ __global__ void GatherEdge(int k,
313310
}
314311
}
315312
}
316-
out_row += BLOCK_WARPS;
313+
out_row += BLOCK_CTAS;
317314
}
318315
}
319316

@@ -337,13 +334,13 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
337334
thrust::exclusive_scan(
338335
output_count, output_count + bs, output_ptr.begin(), 0);
339336

340-
constexpr int WARP_SIZE = 32;
341-
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
342-
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
343-
const dim3 block(WARP_SIZE, BLOCK_WARPS);
337+
constexpr int CTA_SIZE = 128;
338+
constexpr int BLOCK_CTAS = 128 / CTA_SIZE;
339+
constexpr int TILE_SIZE = BLOCK_CTAS;
340+
const dim3 block(CTA_SIZE, BLOCK_CTAS);
344341
const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
345342

346-
FisherYatesSampleKernel<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
343+
FisherYatesSampleKernel<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
347344
<<<grid, block, 0, dev_ctx.stream()>>>(0,
348345
sample_size,
349346
bs,
@@ -352,7 +349,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
352349
perm_data,
353350
col_ptr);
354351

355-
GatherEdge<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
352+
GatherEdge<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
356353
<<<grid, block, 0, dev_ctx.stream()>>>(
357354
sample_size,
358355
bs,

0 commit comments

Comments (0)