@@ -58,7 +58,7 @@ struct MaxFunctor {
58
58
}
59
59
};
60
60
61
- template <typename T, int WARP_SIZE , int BLOCK_WARPS , int TILE_SIZE>
61
+ template <typename T, int CTA_SIZE , int BLOCK_CTAS , int TILE_SIZE>
62
62
__global__ void SampleKernel (const uint64_t rand_seed,
63
63
int k,
64
64
const int64_t num_nodes,
@@ -71,52 +71,51 @@ __global__ void SampleKernel(const uint64_t rand_seed,
71
71
T* output_eids,
72
72
int * output_ptr,
73
73
bool return_eids) {
74
- assert (blockDim .x == WARP_SIZE);
75
- assert (blockDim .y == BLOCK_WARPS);
74
+ assert (blockDim .x == CTA_SIZE);
76
75
77
76
int64_t out_row = blockIdx .x * TILE_SIZE + threadIdx .y ;
78
77
const int64_t last_row =
79
78
min (static_cast <int64_t >(blockIdx .x + 1 ) * TILE_SIZE, num_nodes);
80
79
#ifdef PADDLE_WITH_HIP
81
80
hiprandState rng;
82
81
hiprand_init (rand_seed * gridDim .x + blockIdx .x ,
83
- threadIdx .y * WARP_SIZE + threadIdx .x ,
82
+ threadIdx .y * CTA_SIZE + threadIdx .x ,
84
83
0 ,
85
84
&rng);
86
85
#else
87
- curandState rng;
86
+ curandStatePhilox4_32_10_t rng;
88
87
curand_init (rand_seed * gridDim .x + blockIdx .x ,
89
- threadIdx .y * WARP_SIZE + threadIdx .x ,
88
+ threadIdx .y * CTA_SIZE + threadIdx .x ,
90
89
0 ,
91
90
&rng);
92
91
#endif
93
92
94
93
while (out_row < last_row) {
95
94
T node = nodes[out_row];
96
95
if (node > len_col_ptr - 1 ) {
97
- out_row += BLOCK_WARPS ;
96
+ out_row += BLOCK_CTAS ;
98
97
continue ;
99
98
}
100
99
T in_row_start = col_ptr[node];
101
100
int deg = col_ptr[node + 1 ] - in_row_start;
102
101
int out_row_start = output_ptr[out_row];
103
102
104
103
if (deg <= k) {
105
- for (int idx = threadIdx .x ; idx < deg; idx += WARP_SIZE ) {
104
+ for (int idx = threadIdx .x ; idx < deg; idx += CTA_SIZE ) {
106
105
output[out_row_start + idx] = row[in_row_start + idx];
107
106
if (return_eids) {
108
107
output_eids[out_row_start + idx] = eids[in_row_start + idx];
109
108
}
110
109
}
111
110
} else {
112
- for (int idx = threadIdx .x ; idx < k; idx += WARP_SIZE ) {
111
+ for (int idx = threadIdx .x ; idx < k; idx += CTA_SIZE ) {
113
112
output[out_row_start + idx] = idx;
114
113
}
115
114
#ifdef PADDLE_WITH_CUDA
116
- __syncwarp ();
115
+ __syncthreads ();
117
116
#endif
118
117
119
- for (int idx = k + threadIdx .x ; idx < deg; idx += WARP_SIZE ) {
118
+ for (int idx = k + threadIdx .x ; idx < deg; idx += CTA_SIZE ) {
120
119
#ifdef PADDLE_WITH_HIP
121
120
const int num = hiprand (&rng) % (idx + 1 );
122
121
#else
@@ -129,10 +128,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
129
128
}
130
129
}
131
130
#ifdef PADDLE_WITH_CUDA
132
- __syncwarp ();
131
+ __syncthreads ();
133
132
#endif
134
133
135
- for (int idx = threadIdx .x ; idx < k; idx += WARP_SIZE ) {
134
+ for (int idx = threadIdx .x ; idx < k; idx += CTA_SIZE ) {
136
135
T perm_idx = output[out_row_start + idx] + in_row_start;
137
136
output[out_row_start + idx] = row[perm_idx];
138
137
if (return_eids) {
@@ -141,7 +140,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
141
140
}
142
141
}
143
142
144
- out_row += BLOCK_WARPS ;
143
+ out_row += BLOCK_CTAS ;
145
144
}
146
145
}
147
146
@@ -181,12 +180,12 @@ void SampleNeighbors(const Context& dev_ctx,
181
180
thrust::exclusive_scan (
182
181
output_count, output_count + bs, output_ptr.begin (), 0 );
183
182
184
- constexpr int WARP_SIZE = 32 ;
185
- constexpr int BLOCK_WARPS = 128 / WARP_SIZE ;
186
- constexpr int TILE_SIZE = BLOCK_WARPS * 16 ;
187
- const dim3 block (WARP_SIZE, BLOCK_WARPS );
183
+ constexpr int CTA_SIZE = 128 ;
184
+ constexpr int BLOCK_CTAS = 128 / CTA_SIZE ;
185
+ constexpr int TILE_SIZE = BLOCK_CTAS ;
186
+ const dim3 block (CTA_SIZE, BLOCK_CTAS );
188
187
const dim3 grid ((bs + TILE_SIZE - 1 ) / TILE_SIZE);
189
- SampleKernel<T, WARP_SIZE, BLOCK_WARPS , TILE_SIZE>
188
+ SampleKernel<T, CTA_SIZE, BLOCK_CTAS , TILE_SIZE>
190
189
<<<grid, block, 0 , dev_ctx.stream()>>> (
191
190
0 ,
192
191
sample_size,
@@ -202,16 +201,15 @@ void SampleNeighbors(const Context& dev_ctx,
202
201
return_eids);
203
202
}
204
203
205
- template <typename T, int WARP_SIZE , int BLOCK_WARPS , int TILE_SIZE>
204
+ template <typename T, int CTA_SIZE , int BLOCK_CTAS , int TILE_SIZE>
206
205
__global__ void FisherYatesSampleKernel (const uint64_t rand_seed,
207
206
int k,
208
207
const int64_t num_rows,
209
208
const int64_t len_col_ptr,
210
209
const T* in_rows,
211
210
T* src,
212
211
const T* dst_count) {
213
- assert (blockDim .x == WARP_SIZE);
214
- assert (blockDim .y == BLOCK_WARPS);
212
+ assert (blockDim .x == CTA_SIZE);
215
213
216
214
int64_t out_row = blockIdx .x * TILE_SIZE + threadIdx .y ;
217
215
const int64_t last_row =
@@ -221,15 +219,15 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
221
219
hiprand_init (
222
220
rand_seed * gridDim .x + blockIdx .x , threadIdx .y + threadIdx .x , 0 , &rng);
223
221
#else
224
- curandState rng;
222
+ curandStatePhilox4_32_10_t rng;
225
223
curand_init (
226
224
rand_seed * gridDim .x + blockIdx .x , threadIdx .y + threadIdx .x , 0 , &rng);
227
225
#endif
228
226
229
227
while (out_row < last_row) {
230
228
const T row = in_rows[out_row];
231
229
if (row > len_col_ptr - 1 ) {
232
- out_row += BLOCK_WARPS ;
230
+ out_row += BLOCK_CTAS ;
233
231
continue ;
234
232
}
235
233
const T in_row_start = dst_count[row];
@@ -241,7 +239,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
241
239
} else {
242
240
split = deg - k;
243
241
}
244
- for (int idx = split + threadIdx .x ; idx <= deg - 1 ; idx += WARP_SIZE ) {
242
+ for (int idx = split + threadIdx .x ; idx <= deg - 1 ; idx += CTA_SIZE ) {
245
243
#ifdef PADDLE_WITH_HIP
246
244
const int num = hiprand (&rng) % (idx + 1 );
247
245
#else
@@ -254,14 +252,14 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
254
252
src[in_row_start + idx])));
255
253
}
256
254
#ifdef PADDLE_WITH_CUDA
257
- __syncwarp ();
255
+ __syncthreads ();
258
256
#endif
259
257
}
260
- out_row += BLOCK_WARPS ;
258
+ out_row += BLOCK_CTAS ;
261
259
}
262
260
}
263
261
264
- template <typename T, int WARP_SIZE , int BLOCK_WARPS , int TILE_SIZE>
262
+ template <typename T, int CTA_SIZE , int BLOCK_CTAS , int TILE_SIZE>
265
263
__global__ void GatherEdge (int k,
266
264
int64_t num_rows,
267
265
const T* in_rows,
@@ -273,8 +271,7 @@ __global__ void GatherEdge(int k,
273
271
int * output_ptr,
274
272
T* perm_data,
275
273
bool return_eids) {
276
- assert (blockDim .x == WARP_SIZE);
277
- assert (blockDim .y == BLOCK_WARPS);
274
+ assert (blockDim .x == CTA_SIZE);
278
275
279
276
int64_t out_row = blockIdx .x * TILE_SIZE + threadIdx .y ;
280
277
const int64_t last_row =
@@ -287,7 +284,7 @@ __global__ void GatherEdge(int k,
287
284
const T out_row_start = output_ptr[out_row];
288
285
289
286
if (deg <= k) {
290
- for (int idx = threadIdx .x ; idx < deg; idx += WARP_SIZE ) {
287
+ for (int idx = threadIdx .x ; idx < deg; idx += CTA_SIZE ) {
291
288
outputs[out_row_start + idx] = src[in_row_start + idx];
292
289
if (return_eids) {
293
290
output_eids[out_row_start + idx] = eids[in_row_start + idx];
@@ -304,7 +301,7 @@ __global__ void GatherEdge(int k,
304
301
end = deg;
305
302
}
306
303
307
- for (int idx = begin + threadIdx .x ; idx < end; idx += WARP_SIZE ) {
304
+ for (int idx = begin + threadIdx .x ; idx < end; idx += CTA_SIZE ) {
308
305
outputs[out_row_start + idx - begin] =
309
306
src[perm_data[in_row_start + idx]];
310
307
if (return_eids) {
@@ -313,7 +310,7 @@ __global__ void GatherEdge(int k,
313
310
}
314
311
}
315
312
}
316
- out_row += BLOCK_WARPS ;
313
+ out_row += BLOCK_CTAS ;
317
314
}
318
315
}
319
316
@@ -337,13 +334,13 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
337
334
thrust::exclusive_scan (
338
335
output_count, output_count + bs, output_ptr.begin (), 0 );
339
336
340
- constexpr int WARP_SIZE = 32 ;
341
- constexpr int BLOCK_WARPS = 128 / WARP_SIZE ;
342
- constexpr int TILE_SIZE = BLOCK_WARPS * 16 ;
343
- const dim3 block (WARP_SIZE, BLOCK_WARPS );
337
+ constexpr int CTA_SIZE = 128 ;
338
+ constexpr int BLOCK_CTAS = 128 / CTA_SIZE ;
339
+ constexpr int TILE_SIZE = BLOCK_CTAS ;
340
+ const dim3 block (CTA_SIZE, BLOCK_CTAS );
344
341
const dim3 grid ((bs + TILE_SIZE - 1 ) / TILE_SIZE);
345
342
346
- FisherYatesSampleKernel<T, WARP_SIZE, BLOCK_WARPS , TILE_SIZE>
343
+ FisherYatesSampleKernel<T, CTA_SIZE, BLOCK_CTAS , TILE_SIZE>
347
344
<<<grid, block, 0 , dev_ctx.stream()>>> (0 ,
348
345
sample_size,
349
346
bs,
@@ -352,7 +349,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
352
349
perm_data,
353
350
col_ptr);
354
351
355
- GatherEdge<T, WARP_SIZE, BLOCK_WARPS , TILE_SIZE>
352
+ GatherEdge<T, CTA_SIZE, BLOCK_CTAS , TILE_SIZE>
356
353
<<<grid, block, 0 , dev_ctx.stream()>>> (
357
354
sample_size,
358
355
bs,
0 commit comments