1 | | -#include "common.cuh" |
| 1 | +#include <algorithm> |
| 2 | +#include <cstdint> |
| 3 | + |
2 | 4 | #include "argmax.cuh" |
| 5 | +#include "common.cuh" |
3 | 6 | #include "sum.cuh" |
4 | 7 |
5 | | -#include <cstdint> |
|	8	| +static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t ncols) {
|	9	| +    const int64_t row = blockIdx.x; // one block per row
6 | 10 |
7 | | -static __global__ void argmax_f32( |
8 | | - const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) { |
| 11 | + float maxval = -FLT_MAX; |
| 12 | + int argmax = -1; |
9 | 13 |
10 | | - int argmax_thread = 0; |
11 | | - const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE; |
|	14	| +    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) { // each thread scans a strided slice of the row
| 15 | + const float val = x[row * ncols + col]; |
| 16 | + if (val > maxval) { |
| 17 | + maxval = val; |
| 18 | + argmax = col; |
| 19 | + } |
| 20 | + } |
12 | 21 |
13 | 22 | #pragma unroll |
14 | | - for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) { |
15 | | - const int64_t row = row0 + row1; |
16 | | - |
17 | | - if (row >= nrows) { |
18 | | - break; |
|	23	| +    for (int mask = 16; mask > 0; mask >>= 1) { // butterfly reduction of (maxval, argmax) across the warp
| 24 | + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE); |
| 25 | + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE); |
| 26 | + if (val > maxval) { |
| 27 | + maxval = val; |
| 28 | + argmax = col; |
19 | 29 | } |
| 30 | + } |
20 | 31 |
21 | | - float maxval = -FLT_MAX; |
22 | | - int argmax = -1; |
23 | | - |
24 | | - for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) { |
25 | | - const float val = x[row*ncols + col]; |
26 | | - const int bigger = val > maxval; |
27 | | - const int not_bigger = bigger ^ 0x00000001; |
28 | | - |
29 | | - maxval = maxval*not_bigger + val*bigger; |
30 | | - argmax = argmax*not_bigger + col*bigger; |
| 32 | + const int n_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; |
| 33 | + const int lane_id = threadIdx.x % WARP_SIZE; |
| 34 | + const int warp_id = threadIdx.x / WARP_SIZE; |
|	35	| +    if (n_warps > 1) { // second stage: combine the per-warp results via shared memory
| 36 | + constexpr int max_warps = 1024 / WARP_SIZE; |
| 37 | + __shared__ float shared_maxval[max_warps]; |
| 38 | + __shared__ int shared_argmax[max_warps]; |
| 39 | + if (lane_id == 0) { |
| 40 | + shared_maxval[warp_id] = maxval; |
| 41 | + shared_argmax[warp_id] = argmax; |
31 | 42 | } |
32 | 43 |
33 | | -#pragma unroll |
34 | | - for (int mask = 16; mask > 0; mask >>= 1) { |
35 | | - const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE); |
36 | | - const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE); |
37 | | - const int bigger = val > maxval; |
38 | | - const int not_bigger = bigger ^ 0x00000001; |
39 | | - |
40 | | - maxval = maxval*not_bigger + val*bigger; |
41 | | - argmax = argmax*not_bigger + col*bigger; |
|	44	| +        if (warp_id == 0 && lane_id >= n_warps) { // pad the unused slots with a neutral element; only warp 0 writes them
|	45	| +            shared_maxval[lane_id] = -FLT_MAX;
|	46	| +            shared_argmax[lane_id] = -1;
|	47	| +        }
| 48 | + __syncthreads(); |
| 49 | + |
| 50 | + if (warp_id == 0) { |
| 51 | + maxval = shared_maxval[lane_id]; |
| 52 | + argmax = shared_argmax[lane_id]; |
|	53	| +            for (int mask = 16; mask > 0; mask >>= 1) { // final reduction of the per-warp winners within warp 0
| 54 | + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE); |
| 55 | + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE); |
| 56 | + if (val > maxval) { |
| 57 | + maxval = val; |
| 58 | + argmax = col; |
| 59 | + } |
| 60 | + } |
42 | 61 | } |
43 | | - |
44 | | - const int store = row1 == threadIdx.x; |
45 | | - argmax_thread += store*argmax; |
46 | 62 | } |
47 | 63 |
48 | | - const int row = row0 + threadIdx.x; |
49 | | - |
50 | | - if (row >= nrows) { |
51 | | - return; |
| 64 | + if (warp_id == 0 && lane_id == 0) { |
| 65 | + dst[row] = argmax; |
52 | 66 | } |
53 | | - |
54 | | - dst[row] = argmax_thread; |
55 | 67 | } |
56 | 68 |
57 | 69 | void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
58 | 70 | const ggml_tensor * src0 = dst->src[0]; |
59 | 71 |
60 | 72 | GGML_ASSERT(src0->type == GGML_TYPE_F32); |
61 | | - GGML_ASSERT( dst->type == GGML_TYPE_I32); |
| 73 | + GGML_ASSERT(dst->type == GGML_TYPE_I32); |
62 | 74 |
63 | 75 | GGML_ASSERT(ggml_is_contiguous(src0)); |
64 | 76 |
65 | 77 | const int64_t ne00 = src0->ne[0]; |
66 | 78 | const int64_t nrows = ggml_nrows(src0); |
67 | 79 |
68 | 80 | const float * src0_d = (const float *) src0->data; |
69 | | - int32_t * dst_d = (int32_t *) dst->data; |
| 81 | + int32_t * dst_d = (int32_t *) dst->data; |
70 | 82 |
71 | 83 | cudaStream_t stream = ctx.stream(); |
72 | 84 |
73 | | - const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE; |
74 | | - |
75 | | - const dim3 blocks_dim(WARP_SIZE, 1, 1); |
| 85 | + const int64_t num_blocks = nrows; |
|	86	| +    const dim3 blocks_dim(std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1)/WARP_SIZE*WARP_SIZE), 1, 1); // round up to a full warp so every lane named in __shfl_xor_sync is launched
76 | 87 | const dim3 blocks_num(num_blocks, 1, 1); |
77 | 88 |
78	|	| -    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
|	89	| +    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
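A quick way to sanity-check the reduction strategy above (not part of this commit) is a standalone harness along the following lines. Everything in it is hypothetical scaffolding invented for illustration: the kernel name argmax_rows, the sizes in main, and the CPU reference; only the two-stage shuffle reduction and the rounded-up block size mirror the diff.

#include <algorithm>
#include <cfloat>
#include <cstdio>
#include <cstdlib>
#include <vector>

#define WARP_SIZE 32

static __global__ void argmax_rows(const float * x, int * dst, const int ncols) {
    const int row = blockIdx.x; // one block per row

    // each thread keeps a running argmax over a strided slice of the row
    float maxval = -FLT_MAX;
    int   argmax = -1;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        const float val = x[row*ncols + col];
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

    // stage 1: butterfly reduction of (maxval, argmax) within each warp
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

    // stage 2: warp 0 combines the per-warp winners via shared memory
    const int n_warps = blockDim.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;
    __shared__ float s_max[WARP_SIZE];
    __shared__ int   s_arg[WARP_SIZE];
    if (lane_id == 0) {
        s_max[warp_id] = maxval;
        s_arg[warp_id] = argmax;
    }
    __syncthreads();
    if (warp_id == 0) {
        // lanes past n_warps load a neutral element instead of stale shared memory
        maxval = lane_id < n_warps ? s_max[lane_id] : -FLT_MAX;
        argmax = lane_id < n_warps ? s_arg[lane_id] : -1;
        for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
            const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
            const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
            if (val > maxval) {
                maxval = val;
                argmax = col;
            }
        }
        if (lane_id == 0) {
            dst[row] = argmax;
        }
    }
}

int main() {
    const int nrows = 8, ncols = 1000; // ncols deliberately not a multiple of 32
    std::vector<float> h_x(nrows*ncols);
    for (float & v : h_x) { v = (float) rand() / RAND_MAX; }

    float * d_x; int * d_dst;
    cudaMalloc(&d_x, h_x.size()*sizeof(float));
    cudaMalloc(&d_dst, nrows*sizeof(int));
    cudaMemcpy(d_x, h_x.data(), h_x.size()*sizeof(float), cudaMemcpyHostToDevice);

    // block size rounded up to a full warp and capped at 1024, as in the launcher above
    const int n_threads = std::min(1024, (ncols + WARP_SIZE - 1)/WARP_SIZE*WARP_SIZE);
    argmax_rows<<<nrows, n_threads>>>(d_x, d_dst, ncols);

    std::vector<int> h_dst(nrows);
    cudaMemcpy(h_dst.data(), d_dst, nrows*sizeof(int), cudaMemcpyDeviceToHost);

    for (int row = 0; row < nrows; ++row) {
        int ref = 0;
        for (int col = 1; col < ncols; ++col) {
            if (h_x[row*ncols + col] > h_x[row*ncols + ref]) { ref = col; }
        }
        printf("row %d: gpu=%d cpu=%d %s\n", row, h_dst[row], ref, h_dst[row] == ref ? "OK" : "MISMATCH");
    }

    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}

With ties in a row, the GPU and CPU may legitimately report different indices of the same maximum value, since a butterfly reduction does not visit columns left to right; with random floats this is unlikely to matter.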