1- #include " common.cuh"
1+ #include < algorithm>
2+ #include < cstdint>
3+
24#include " argmax.cuh"
5+ #include " common.cuh"
36#include " sum.cuh"
47
5- #include < cstdint>
// Per-row argmax over a row-major f32 matrix.
//
// Launch layout: one thread block per row (blockIdx.x == row index),
// blockDim.x a multiple of WARP_SIZE and at most 1024. Each thread scans a
// strided subset of the row's columns, then the partial results are reduced
// first within each warp (shuffle butterfly) and, when the block has more
// than one warp, across warps through shared memory.
//
// x:     [nrows, ncols] input, row-major, device pointer
// dst:   [nrows] output, one int32 column index per row
// ncols: number of columns per row
//
// Ties are broken toward the lower column index so the result matches the
// first-occurrence semantics of a sequential argmax. The unsigned compare in
// the tie-break also guarantees the -1 sentinel (threads that never saw a
// column because ncols < blockDim.x) can never win against a real column.
static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
    const int64_t row = blockIdx.x;

    // Thread-local scan: strict '>' keeps the first maximum within this
    // thread's strided column subset.
    float maxval = -FLT_MAX;
    int   argmax = -1;
    const float * rowx = x + row * ncols;

    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
        const float val = rowx[col];
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

    // Intra-warp butterfly reduction; after the loop every lane of the warp
    // holds the warp-wide best (value, column) pair.
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
        // On ties prefer the lower column; (unsigned)-1 == UINT_MAX, so the
        // "no column seen" sentinel always loses.
        if (val > maxval || (val == maxval && (uint32_t) col < (uint32_t) argmax)) {
            maxval = val;
            argmax = col;
        }
    }

    const int n_warps = blockDim.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;
    if (n_warps > 1) {
        // Cross-warp reduction: lane 0 of each warp publishes its partial
        // result, then warp 0 reduces the partials with the same butterfly.
        // n_warps is uniform across the block, so the barrier below is safe.
        constexpr int max_warps = 1024 / WARP_SIZE;
        __shared__ float shared_maxval[max_warps];
        __shared__ int   shared_argmax[max_warps];
        if (lane_id == 0) {
            shared_maxval[warp_id] = maxval;
            shared_argmax[warp_id] = argmax;
        }

        __syncthreads();

        if (warp_id == 0) {
            // Lanes >= n_warps keep warp 0's own partial, which duplicates
            // shared_maxval[0] and is therefore harmless for the max.
            if (lane_id < n_warps) {
                maxval = shared_maxval[lane_id];
                argmax = shared_argmax[lane_id];
            }
#pragma unroll
            for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
                if (val > maxval || (val == maxval && (uint32_t) col < (uint32_t) argmax)) {
                    maxval = val;
                    argmax = col;
                }
            }
        }
    }

    if (warp_id == 0 && lane_id == 0) {
        dst[row] = argmax;
    }
}
5668
5769void ggml_cuda_argmax (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -70,10 +82,10 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
7082
7183 cudaStream_t stream = ctx.stream ();
7284
73- const int64_t num_blocks = ( nrows + WARP_SIZE - 1 ) / WARP_SIZE ;
74-
75- const dim3 blocks_dim (WARP_SIZE , 1 , 1 );
85+ const int64_t num_blocks = nrows;
86+ const int64_t num_threads = std::min< int64_t >( 1024 , (ne00 + WARP_SIZE - 1 ) / WARP_SIZE * WARP_SIZE);
87+ const dim3 blocks_dim (num_threads , 1 , 1 );
7688 const dim3 blocks_num (num_blocks, 1 , 1 );
7789
78- argmax_f32<<<blocks_num, blocks_dim, 0 , stream>>> (src0_d, dst_d, ne00, nrows );
90+ argmax_f32<<<blocks_num, blocks_dim, 0 , stream>>> (src0_d, dst_d, ne00);
7991}
0 commit comments