#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

-#define CUDA_CHECK(call) if((call) != cudaSuccess) {cudaError_t err = cudaGetLastError(); std::cout << "CUDA error calling ""#call"", code is " << err << std::endl;}
-
#define CUDA_NUM_THREADS 64
#define GET_CUDA_BLOCKS(N) ceil((float)N / CUDA_NUM_THREADS)

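Note on the macro removed above: as written it never stringized its argument, because "#call" sits inside adjacent string literals and is spliced in as literal text rather than expanded, and it reported cudaGetLastError() instead of the return code of the guarded call. If error checking is ever reintroduced, a conventional replacement might look like the following sketch (hypothetical, not part of this commit):

// Hypothetical replacement sketch, not part of this commit: stringize the
// call with the # operator and report the error code returned by the call itself.
#define CUDA_CHECK(call)                                                  \
    do {                                                                  \
        cudaError_t err = (call);                                         \
        if (err != cudaSuccess) {                                         \
            std::cout << "CUDA error calling \"" #call "\", code is "     \
                      << cudaGetErrorString(err) << std::endl;            \
        }                                                                 \
    } while (0)
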
__global__ void adj_vec_kernel(
    int batch_size,
    int * edge_index,
    int vertex_count,
    int * adj_vec,
    int * adj_vec_len,
    int max_adj_per_node){

    const int edge_count = vertex_count - 1;
    const int batch_idx = blockIdx.x;
    const int thread_idx = threadIdx.x;
    const int thread_count = blockDim.x;

    edge_index += batch_idx * edge_count * 2;
    adj_vec += batch_idx * vertex_count * max_adj_per_node;
    adj_vec_len += batch_idx * vertex_count;

    for (int i = thread_idx; i < edge_count; i += thread_count){
        int source = edge_index[2 * i];
        int target = edge_index[2 * i + 1];
        int source_len = atomicAdd(&(adj_vec_len[source]), 1);
        adj_vec[source * max_adj_per_node + source_len] = target;
        int target_len = atomicAdd(&(adj_vec_len[target]), 1);
        adj_vec[target * max_adj_per_node + target_len] = source;
    }
}

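For context, a minimal host-side launch sketch for adj_vec_kernel, assuming one block per batch element (implied by the use of blockIdx.x as the batch index) and device buffers that are already allocated; the variable names are illustrative:

// Hypothetical host-side launch sketch (buffer names are illustrative).
// adj_vec_len must start at zero because the kernel increments it atomically.
cudaMemset(adj_vec_len, 0, batch_size * vertex_count * sizeof(int));
adj_vec_kernel<<<batch_size, CUDA_NUM_THREADS>>>(
    batch_size, edge_index, vertex_count,
    adj_vec, adj_vec_len, max_adj_per_node);
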
__global__ void breadth_first_sort_kernel(
    int * sorted_index,
    int * sorted_parent_index,
    int * sorted_child_index,
    int * adj_vec,
    int * adj_vec_len,
    int * parent_index,
    int batch_size,
    int vertex_count,
    int max_adj_per_node){

    const int batch_idx = blockIdx.x;
    const int thread_idx = threadIdx.x;
    const int thread_count = blockDim.x;

    adj_vec += batch_idx * vertex_count * max_adj_per_node;
    adj_vec_len += batch_idx * vertex_count;
    parent_index += batch_idx * vertex_count;
    sorted_index += batch_idx * vertex_count;
    sorted_parent_index += batch_idx * vertex_count;
    sorted_child_index += batch_idx * vertex_count * max_adj_per_node;

    __shared__ int sorted_len;
    if (thread_idx == 0) {
        sorted_len = 1;
        parent_index[0] = 0;
        sorted_index[0] = 0;
        sorted_parent_index[0] = 0;
    }
    __syncthreads();

    int i = thread_idx;
    while (i < vertex_count){
        if ((sorted_index[i] > 0) || (i == 0)){
            int child_index = 0;
            int par = parent_index[i];
            int cur = sorted_index[i];
            for (int j = 0; j < adj_vec_len[cur]; j++){
                int child = adj_vec[cur * max_adj_per_node + j];
                if (child != par){
                    int pos = atomicAdd(&(sorted_len), 1);
                    sorted_index[pos] = child;
                    parent_index[pos] = cur;
                    sorted_parent_index[pos] = i;
                    sorted_child_index[i * max_adj_per_node + child_index] = pos;
                    child_index++;
                }
            }
            i += thread_count;
        }
        __syncthreads();
    }
}

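For intuition about the output layout, a small worked trace, assuming a single batch and a path-shaped tree 0-1-2-3 (edges (0,1), (1,2), (2,3)) rooted at vertex 0:

// Illustrative trace only (single batch, path graph 0-1-2-3):
//   sorted_index        = [0, 1, 2, 3]  // vertices in BFS order from root 0
//   sorted_parent_index = [0, 0, 1, 2]  // BFS position of each vertex's parent
//   sorted_child_index[i * max_adj_per_node + k]
//                       = BFS position of the k-th child of the vertex at position i
// parent_index serves as scratch space holding each visited vertex's parent id;
// positions not yet filled are simply re-checked after each __syncthreads() round.
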
std::tuple<at::Tensor, at::Tensor, at::Tensor>
@@ -135,5 +133,3 @@ bfs_forward( |

    return std::make_tuple(sorted_index_tensor, sorted_parent_tensor, sorted_child_tensor);
}
-
-