@@ -66,7 +66,6 @@ __global__ void GatherKernelV2(const T* inputs,
                                const int* index_groups,
                                const int non_zero_num,
                                const int kernel_size,
-                               const int max_voxel,
                                const int channels,
                                const int buffer_count,
                                T* output) {
@@ -84,11 +83,10 @@ __global__ void GatherKernelV2(const T* inputs,
 #pragma unroll
     for (int it = 0; it < buffer_count; it++) {
       int len = index_counts[indices_i + it * non_zero_num];
-      const int group_offset = it * kernel_size * max_voxel * non_zero_num;
+      const int group_offset = it * kernel_size * non_zero_num;
 #pragma unroll
       for (int j = 0; j < len; j++) {
-        int out_i = index_groups[indices_i * kernel_size * max_voxel + j +
-                                 group_offset];
+        int out_i = index_groups[indices_i * kernel_size + j + group_offset];
         phi::Store<T, VecSize>(
             in_vec, output + out_i * channels + channels_i * VecSize);
       }
@@ -130,7 +128,6 @@ inline void GatherV2(const GPUContext& dev_ctx,
                      const int* index_groups,
                      const int non_zero_num,
                      const int kernel_size,
-                     const int max_voxel,
                      const int channels,
                      const int buffer_count,
                      T* output) {
@@ -146,7 +143,6 @@ inline void GatherV2(const GPUContext& dev_ctx,
                                                      index_groups,
                                                      non_zero_num,
                                                      kernel_size,
-                                                     max_voxel,
                                                      channels,
                                                      buffer_count,
                                                      output);
@@ -161,7 +157,6 @@ inline void GatherV2(const GPUContext& dev_ctx,
                                                      index_groups,
                                                      non_zero_num,
                                                      kernel_size,
-                                                     max_voxel,
                                                      channels,
                                                      buffer_count,
                                                      output);
@@ -207,7 +202,7 @@ __global__ void UniqueKernel(const IntT* in_indexs,
 template <typename IntT>
 __global__ void GroupIndexs(const int* out_index_table,
                             const int n,
-                            const int offset,
+                            const int kernel_size,
                             IntT* out_indexs,
                             int* out_index_counts,
                             int* out_index_groups) {
@@ -219,7 +214,7 @@ __global__ void GroupIndexs(const int* out_index_table,
     // kernel_size at most
     int j = atomicAdd(out_index_counts + real_index, 1);
     // nnz * kernel_size
-    out_index_groups[real_index * offset + j] = i;
+    out_index_groups[real_index * kernel_size + j] = i;
   }
 }
 
@@ -303,36 +298,18 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
   }
 }
 
-template <typename IntT, bool save_out_index = true>
+template <typename IntT>
 __global__ void GetOutIndexTable(const IntT* indices,
                                  const IntT non_zero_num,
                                  const Dims4D dims,
-                                 int* out_index_table,
-                                 int* out_index_table2,
-                                 int* max_voxel) {
-  __shared__ int cache_max;
-  if (threadIdx.x == 0) {
-    cache_max = 0;
-  }
-  __syncthreads();
-
+                                 int* out_index_table) {
   CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
     IntT batch = indices[i];
     IntT in_z = indices[i + non_zero_num];
     IntT in_y = indices[i + 2 * non_zero_num];
     IntT in_x = indices[i + 3 * non_zero_num];
     IntT index = PointToIndex(batch, in_x, in_y, in_z, dims);
-    if (save_out_index) {
-      out_index_table[index] = i == 0 ? -1 : i;
-    }
-
-    int count = atomicAdd(out_index_table2 + index, 1);
-    atomicMax(&cache_max, count);
-  }
-
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    atomicMax(max_voxel, cache_max + 1);
+    out_index_table[index] = i == 0 ? -1 : i;
   }
 }
 
@@ -341,22 +318,10 @@ __global__ void GetOutIndexTable(int* indexs,
                                  const int non_zero_num,
                                  const Dims4D out_dims,
                                  int* out_index_table,
-                                 int* out_index_table2,
-                                 int* max_voxel,
                                  IntT* out_indices) {
-  __shared__ int cache_max;
-  if (threadIdx.x == 0) {
-    cache_max = 0;
-  }
-  __syncthreads();
-
   CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
     IntT index = static_cast<IntT>(indexs[i]);
     out_index_table[index] = i;
-
-    int count = atomicAdd(out_index_table2 + index, 1);
-    atomicMax(&cache_max, count);
-
     IntT batch, x, y, z;
     phi::funcs::sparse::IndexToPoint<Dims4D>(
         index, out_dims, &batch, &x, &y, &z);
@@ -367,11 +332,6 @@ __global__ void GetOutIndexTable(int* indexs,
     out_indices[i + non_zero_num * 3] = x;
     indexs[i] = 0;
   }
-
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    atomicMax(max_voxel, cache_max + 1);
-  }
 }
 
 template <typename IntT>
@@ -491,7 +451,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
 
 template <typename IntT>
 __global__ void GroupIndexs(const int n,
-                            const int offset,
+                            const int kernel_size,
                             const IntT* indexs,
                             int* index_counts,
                             int* index_groups) {
@@ -500,15 +460,15 @@ __global__ void GroupIndexs(const int n,
     // kernel_size at most
     int j = atomicAdd(index_counts + index, 1);
     // nnz * kernel_size
-    index_groups[index * offset + j] = i;
+    index_groups[index * kernel_size + j] = i;
   }
 }
 
 // double space to reduce atomicAdd conflict
 template <typename IntT>
 __global__ void GroupIndexsV2(const int rulebook_len,
                               const int non_zero_num,
-                              const int offset,
+                              const int kernel_size,
                               const int half_kernel_offset,
                               const IntT* indexs,
                               int* index_counts,
@@ -519,11 +479,11 @@ __global__ void GroupIndexsV2(const int rulebook_len,
         i < half_kernel_offset ? index_counts : index_counts + non_zero_num;
     int* groups_ptr = i < half_kernel_offset
                           ? index_groups
-                          : index_groups + non_zero_num * offset;
+                          : index_groups + non_zero_num * kernel_size;
     // conflict kernel_size times at most
     int j = atomicAdd(counts_ptr + index, 1);
     // nnz * kernel_size
-    groups_ptr[index * offset + j] = i;
+    groups_ptr[index * kernel_size + j] = i;
   }
 }
 
@@ -622,10 +582,6 @@ int ProductRuleBook(const Context& dev_ctx,
   DenseTensor out_index_table = phi::Empty<int>(dev_ctx, {table_size});
   int* out_index_table_ptr = out_index_table.data<int>();
 
-  DenseTensor out_index_table2 = phi::Empty<int>(dev_ctx, {table_size + 1});
-  int* out_index_table2_ptr = out_index_table2.data<int>();
-  int* h_max_voxel = h_counter + kernel_size;
-
   if (subm) {
     DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta));
     IntT* rulebook_ptr = tmp_rulebook.data<IntT>();
@@ -636,29 +592,14 @@ int ProductRuleBook(const Context& dev_ctx,
 
     phi::backends::gpu::GpuMemsetAsync(
         out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
-    phi::backends::gpu::GpuMemsetAsync(out_index_table2_ptr,
-                                       0,
-                                       sizeof(int) * (table_size + 1),
-                                       dev_ctx.stream());
 
     auto config =
         phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
-    GetOutIndexTable<IntT>
-        <<<config.block_per_grid,
-           config.thread_per_block,
-           0,
-           dev_ctx.stream()>>>(out_indices.data<IntT>(),
-                               non_zero_num,
-                               d_x_dims,
-                               out_index_table_ptr,
-                               out_index_table2_ptr,
-                               out_index_table2_ptr + table_size);
-    phi::backends::gpu::GpuMemcpyAsync(h_max_voxel,
-                                       out_index_table2_ptr + table_size,
-                                       sizeof(int),
-                                       gpuMemcpyDeviceToHost,
-                                       dev_ctx.stream());
-    dev_ctx.Wait();
+    GetOutIndexTable<IntT><<<config.block_per_grid,
+                             config.thread_per_block,
+                             0,
+                             dev_ctx.stream()>>>(
+        out_indices.data<IntT>(), non_zero_num, d_x_dims, out_index_table_ptr);
 
     size_t cache_size =
         kernel_size * 2 * sizeof(int) +
@@ -712,22 +653,6 @@ int ProductRuleBook(const Context& dev_ctx,
                                                             out_rulebook_ptr);
     *rulebook = out_rulebook;
 
-    unique_value->ResizeAndAllocate(
-        {static_cast<int>(non_zero_num * h_max_voxel[0] * kernel_size)});
-    int* unique_value_ptr = unique_value->data<int>();
-    out_index->ResizeAndAllocate({static_cast<int>(rulebook_len)});
-    int* out_index_ptr = out_index->data<int>();
-    phi::backends::gpu::GpuMemsetAsync(
-        out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
-    GroupIndexs<<<config.block_per_grid,
-                  config.thread_per_block,
-                  0,
-                  dev_ctx.stream()>>>(rulebook_len,
-                                      kernel_size * h_max_voxel[0],
-                                      out_rulebook_ptr + rulebook_len,
-                                      out_index_ptr,
-                                      unique_value_ptr);
-
     return rulebook_len;
 
   } else {
@@ -811,43 +736,25 @@ int ProductRuleBook(const Context& dev_ctx,
 
     IntT* out_indices_ptr = out_indices.data<IntT>();
 
-    phi::backends::gpu::GpuMemsetAsync(
-        out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
-    phi::backends::gpu::GpuMemsetAsync(out_index_table2_ptr,
-                                       0,
-                                       sizeof(int) * (table_size + 1),
-                                       dev_ctx.stream());
-
     config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1);
-    GetOutIndexTable<IntT>
-        <<<config.block_per_grid,
-           config.thread_per_block,
-           0,
-           dev_ctx.stream()>>>(out_index_ptr,
-                               out_nnz,
-                               d_out_dims,
-                               out_index_table_ptr,
-                               out_index_table2_ptr,
-                               out_index_table2_ptr + table_size,
-                               out_indices_ptr);
-    phi::backends::gpu::GpuMemcpyAsync(h_max_voxel,
-                                       out_index_table2_ptr + table_size,
-                                       sizeof(int),
-                                       gpuMemcpyDeviceToHost,
-                                       dev_ctx.stream());
-    dev_ctx.Wait();
-
+    GetOutIndexTable<IntT><<<config.block_per_grid,
+                             config.thread_per_block,
+                             0,
+                             dev_ctx.stream()>>>(out_index_ptr,
+                                                 out_nnz,
+                                                 d_out_dims,
+                                                 out_index_table_ptr,
+                                                 out_indices_ptr);
     config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
-    unique_value->ResizeAndAllocate(
-        {static_cast<int>(out_nnz * h_max_voxel[0] * kernel_size)});
+    unique_value->ResizeAndAllocate({static_cast<int>(out_nnz * kernel_size)});
     int* unique_value_ptr = unique_value->data<int>();
 
     GroupIndexs<<<config.block_per_grid,
                   config.thread_per_block,
                   0,
                   dev_ctx.stream()>>>(out_index_table_ptr,
                                       rulebook_len,
-                                      kernel_size * h_max_voxel[0],
+                                      kernel_size,
                                       rulebook_ptr + rulebook_len,
                                       out_index_ptr,
                                       unique_value_ptr);
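Note (not part of the patch): after this change every output point owns a fixed slot of kernel_size entries in index_groups (instead of kernel_size * max_voxel), filled with atomicAdd on a per-point counter and read back by the gather kernel, which is why the max_voxel plumbing and its device-to-host copy can be dropped. The standalone CUDA sketch below illustrates that counts/groups layout only; the names (BuildGroups, GatherByGroups, out_ids, entry_vals) are invented for illustration, and the gather does a scalar sum rather than the vectorized channel copy done by GatherKernelV2.

#include <cstdio>
#include <cuda_runtime.h>

// Group entries by the output point they target.
// counts[p]                   : number of entries targeting point p
// groups[p * kernel_size + j] : id of the j-th such entry
// Assumes each point is targeted at most kernel_size times, so the
// fixed stride of kernel_size never overflows.
__global__ void BuildGroups(const int* out_ids, int n, int kernel_size,
                            int* counts, int* groups) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    int p = out_ids[i];
    int j = atomicAdd(counts + p, 1);  // at most kernel_size conflicts per point
    groups[p * kernel_size + j] = i;   // stride is kernel_size, not kernel_size * max_voxel
  }
}

// Consume the groups table: accumulate all entry values into their output point.
__global__ void GatherByGroups(const float* entry_vals, const int* counts,
                               const int* groups, int num_points,
                               int kernel_size, float* out) {
  for (int p = blockIdx.x * blockDim.x + threadIdx.x; p < num_points;
       p += gridDim.x * blockDim.x) {
    float sum = 0.f;
    int len = counts[p];
    for (int j = 0; j < len; j++) {
      sum += entry_vals[groups[p * kernel_size + j]];
    }
    out[p] = sum;
  }
}

int main() {
  const int kernel_size = 3, num_points = 4, n = 6;
  // Hypothetical rulebook: entry i contributes h_vals[i] to point h_out_ids[i].
  int h_out_ids[n] = {0, 1, 1, 2, 3, 3};
  float h_vals[n] = {1, 2, 3, 4, 5, 6};

  int *d_out_ids, *d_counts, *d_groups;
  float *d_vals, *d_out;
  cudaMalloc(&d_out_ids, n * sizeof(int));
  cudaMalloc(&d_counts, num_points * sizeof(int));
  cudaMalloc(&d_groups, num_points * kernel_size * sizeof(int));
  cudaMalloc(&d_vals, n * sizeof(float));
  cudaMalloc(&d_out, num_points * sizeof(float));
  cudaMemcpy(d_out_ids, h_out_ids, n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_vals, h_vals, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(d_counts, 0, num_points * sizeof(int));

  BuildGroups<<<1, 128>>>(d_out_ids, n, kernel_size, d_counts, d_groups);
  GatherByGroups<<<1, 128>>>(d_vals, d_counts, d_groups, num_points,
                             kernel_size, d_out);

  float h_out[num_points];
  cudaMemcpy(h_out, d_out, num_points * sizeof(float), cudaMemcpyDeviceToHost);
  for (int p = 0; p < num_points; p++) printf("point %d -> %.0f\n", p, h_out[p]);

  cudaFree(d_out_ids); cudaFree(d_counts); cudaFree(d_groups);
  cudaFree(d_vals); cudaFree(d_out);
  return 0;
}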