| 
1 |  | -#include "common.cuh"  | 
 | 1 | +#include <algorithm>  | 
 | 2 | +#include <cstdint>  | 
 | 3 | + | 
2 | 4 | #include "argmax.cuh"  | 
 | 5 | +#include "common.cuh"  | 
3 | 6 | #include "sum.cuh"  | 
4 | 7 | 
 
  | 
5 |  | -#include <cstdint>  | 
 | 8 | +static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {  | 
 | 9 | +    const int64_t row = blockIdx.x;  | 
6 | 10 | 
 
  | 
7 |  | -static __global__ void argmax_f32(  | 
8 |  | -    const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {  | 
 | 11 | +    float maxval = -FLT_MAX;  | 
 | 12 | +    int   argmax = -1;  | 
9 | 13 | 
 
  | 
10 |  | -    int argmax_thread = 0;  | 
11 |  | -    const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;  | 
 | 14 | +    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {  | 
 | 15 | +        const float val = x[row * ncols + col];  | 
 | 16 | +        if (val > maxval) {  | 
 | 17 | +            maxval = val;  | 
 | 18 | +            argmax = col;  | 
 | 19 | +        }  | 
 | 20 | +    }  | 
12 | 21 | 
 
  | 
13 | 22 | #pragma unroll  | 
14 |  | -    for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {  | 
15 |  | -        const int64_t row = row0 + row1;  | 
16 |  | - | 
17 |  | -        if (row >= nrows) {  | 
18 |  | -            break;  | 
 | 23 | +    for (int mask = 16; mask > 0; mask >>= 1) {  | 
 | 24 | +        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);  | 
 | 25 | +        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);  | 
 | 26 | +        if (val > maxval) {  | 
 | 27 | +            maxval = val;  | 
 | 28 | +            argmax = col;  | 
19 | 29 |         }  | 
 | 30 | +    }  | 
20 | 31 | 
 
  | 
21 |  | -        float maxval = -FLT_MAX;  | 
22 |  | -        int   argmax = -1;  | 
23 |  | - | 
24 |  | -        for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {  | 
25 |  | -            const float val        = x[row*ncols + col];  | 
26 |  | -            const int   bigger     = val > maxval;  | 
27 |  | -            const int   not_bigger = bigger ^ 0x00000001;  | 
28 |  | - | 
29 |  | -            maxval = maxval*not_bigger + val*bigger;  | 
30 |  | -            argmax = argmax*not_bigger + col*bigger;  | 
 | 32 | +    const int n_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;  | 
 | 33 | +    const int lane_id = threadIdx.x % WARP_SIZE;  | 
 | 34 | +    const int warp_id = threadIdx.x / WARP_SIZE;  | 
 | 35 | +    if (n_warps > 1) {  | 
 | 36 | +        constexpr int    max_warps = 1024 / WARP_SIZE;  | 
 | 37 | +        __shared__ float shared_maxval[max_warps];  | 
 | 38 | +        __shared__ int   shared_argmax[max_warps];  | 
 | 39 | +        if (lane_id == 0) {  | 
 | 40 | +            shared_maxval[warp_id] = maxval;  | 
 | 41 | +            shared_argmax[warp_id] = argmax;  | 
31 | 42 |         }  | 
32 | 43 | 
 
  | 
33 |  | -#pragma unroll  | 
34 |  | -        for (int mask = 16; mask > 0; mask >>= 1) {  | 
35 |  | -            const float val        = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);  | 
36 |  | -            const int   col        = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);  | 
37 |  | -            const int   bigger     = val > maxval;  | 
38 |  | -            const int   not_bigger = bigger ^ 0x00000001;  | 
39 |  | - | 
40 |  | -            maxval = maxval*not_bigger + val*bigger;  | 
41 |  | -            argmax = argmax*not_bigger + col*bigger;  | 
 | 44 | +        if (lane_id >= n_warps) {  | 
 | 45 | +            shared_maxval[lane_id] = -FLT_MAX;  | 
 | 46 | +            shared_argmax[lane_id] = -1;  | 
 | 47 | +        }  | 
 | 48 | +        __syncthreads();  | 
 | 49 | + | 
 | 50 | +        if (warp_id == 0) {  | 
 | 51 | +            maxval = shared_maxval[lane_id];  | 
 | 52 | +            argmax = shared_argmax[lane_id];  | 
 | 53 | +            for (int mask = 16; mask > 0; mask >>= 1) {  | 
 | 54 | +                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);  | 
 | 55 | +                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);  | 
 | 56 | +                if (val > maxval) {  | 
 | 57 | +                    maxval = val;  | 
 | 58 | +                    argmax = col;  | 
 | 59 | +                }  | 
 | 60 | +            }  | 
42 | 61 |         }  | 
43 |  | - | 
44 |  | -        const int store = row1 == threadIdx.x;  | 
45 |  | -        argmax_thread += store*argmax;  | 
46 | 62 |     }  | 
47 | 63 | 
 
  | 
48 |  | -    const int row = row0 + threadIdx.x;  | 
49 |  | - | 
50 |  | -    if (row >= nrows) {  | 
51 |  | -        return;  | 
 | 64 | +    if (warp_id == 0 && lane_id == 0) {  | 
 | 65 | +        dst[row] = argmax;  | 
52 | 66 |     }  | 
53 |  | - | 
54 |  | -    dst[row] = argmax_thread;  | 
55 | 67 | }  | 
56 | 68 | 
 
  | 
57 | 69 | void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {  | 
58 | 70 |     const ggml_tensor * src0 = dst->src[0];  | 
59 | 71 | 
 
  | 
60 | 72 |     GGML_ASSERT(src0->type == GGML_TYPE_F32);  | 
61 |  | -    GGML_ASSERT( dst->type == GGML_TYPE_I32);  | 
 | 73 | +    GGML_ASSERT(dst->type == GGML_TYPE_I32);  | 
62 | 74 | 
 
  | 
63 | 75 |     GGML_ASSERT(ggml_is_contiguous(src0));  | 
64 | 76 | 
 
  | 
65 | 77 |     const int64_t ne00  = src0->ne[0];  | 
66 | 78 |     const int64_t nrows = ggml_nrows(src0);  | 
67 | 79 | 
 
  | 
68 | 80 |     const float * src0_d = (const float *) src0->data;  | 
69 |  | -    int32_t     * dst_d  = (int32_t     *) dst->data;  | 
 | 81 | +    int32_t *     dst_d  = (int32_t     *) dst->data;  | 
70 | 82 | 
 
  | 
71 | 83 |     cudaStream_t stream = ctx.stream();  | 
72 | 84 | 
 
  | 
73 |  | -    const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;  | 
74 |  | - | 
75 |  | -    const dim3 blocks_dim(WARP_SIZE, 1, 1);  | 
 | 85 | +    const int64_t num_blocks = nrows;  | 
 | 86 | +    const dim3 blocks_dim(std::min<int64_t>(ne00, 1024), 1, 1);  | 
76 | 87 |     const dim3 blocks_num(num_blocks, 1, 1);  | 
77 | 88 | 
 
  | 
78 | 89 |     argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);  | 
 | 
0 commit comments