Skip to content

Commit 14eb5d4

Browse files
committed
cuda : optimize argmax
1 parent 9abe9ee commit 14eb5d4

File tree

2 files changed

+84
-44
lines changed

2 files changed

+84
-44
lines changed

ggml/src/ggml-cuda/argmax.cu

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,89 @@
1-
#include "common.cuh"
1+
#include <algorithm>
2+
#include <cstdint>
3+
24
#include "argmax.cuh"
5+
#include "common.cuh"
36
#include "sum.cuh"
47

5-
#include <cstdint>
8+
static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
9+
const int64_t row = blockIdx.x;
610

7-
static __global__ void argmax_f32(
8-
const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
11+
float maxval = -FLT_MAX;
12+
int argmax = -1;
913

10-
int argmax_thread = 0;
11-
const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
14+
for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
15+
const float val = x[row * ncols + col];
16+
if (val > maxval) {
17+
maxval = val;
18+
argmax = col;
19+
}
20+
}
1221

1322
#pragma unroll
14-
for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
15-
const int64_t row = row0 + row1;
16-
17-
if (row >= nrows) {
18-
break;
23+
for (int mask = 16; mask > 0; mask >>= 1) {
24+
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
25+
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
26+
if (val > maxval) {
27+
maxval = val;
28+
argmax = col;
1929
}
30+
}
2031

21-
float maxval = -FLT_MAX;
22-
int argmax = -1;
23-
24-
for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
25-
const float val = x[row*ncols + col];
26-
const int bigger = val > maxval;
27-
const int not_bigger = bigger ^ 0x00000001;
28-
29-
maxval = maxval*not_bigger + val*bigger;
30-
argmax = argmax*not_bigger + col*bigger;
32+
const int n_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
33+
const int lane_id = threadIdx.x % WARP_SIZE;
34+
const int warp_id = threadIdx.x / WARP_SIZE;
35+
if (n_warps > 1) {
36+
constexpr int max_warps = 1024 / WARP_SIZE;
37+
__shared__ float shared_maxval[max_warps];
38+
__shared__ int shared_argmax[max_warps];
39+
if (lane_id == 0) {
40+
shared_maxval[warp_id] = maxval;
41+
shared_argmax[warp_id] = argmax;
3142
}
3243

33-
#pragma unroll
34-
for (int mask = 16; mask > 0; mask >>= 1) {
35-
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
36-
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
37-
const int bigger = val > maxval;
38-
const int not_bigger = bigger ^ 0x00000001;
39-
40-
maxval = maxval*not_bigger + val*bigger;
41-
argmax = argmax*not_bigger + col*bigger;
44+
if (lane_id >= n_warps) {
45+
shared_maxval[lane_id] = -FLT_MAX;
46+
shared_argmax[lane_id] = -1;
47+
}
48+
__syncthreads();
49+
50+
if (warp_id == 0) {
51+
maxval = shared_maxval[lane_id];
52+
argmax = shared_argmax[lane_id];
53+
for (int mask = 16; mask > 0; mask >>= 1) {
54+
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
55+
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
56+
if (val > maxval) {
57+
maxval = val;
58+
argmax = col;
59+
}
60+
}
4261
}
43-
44-
const int store = row1 == threadIdx.x;
45-
argmax_thread += store*argmax;
4662
}
4763

48-
const int row = row0 + threadIdx.x;
49-
50-
if (row >= nrows) {
51-
return;
64+
if (warp_id == 0 && lane_id == 0) {
65+
dst[row] = argmax;
5266
}
53-
54-
dst[row] = argmax_thread;
5567
}
5668

5769
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
5870
const ggml_tensor * src0 = dst->src[0];
5971

6072
GGML_ASSERT(src0->type == GGML_TYPE_F32);
61-
GGML_ASSERT( dst->type == GGML_TYPE_I32);
73+
GGML_ASSERT(dst->type == GGML_TYPE_I32);
6274

6375
GGML_ASSERT(ggml_is_contiguous(src0));
6476

6577
const int64_t ne00 = src0->ne[0];
6678
const int64_t nrows = ggml_nrows(src0);
6779

6880
const float * src0_d = (const float *) src0->data;
69-
int32_t * dst_d = (int32_t *) dst->data;
81+
int32_t * dst_d = (int32_t *) dst->data;
7082

7183
cudaStream_t stream = ctx.stream();
7284

73-
const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;
74-
75-
const dim3 blocks_dim(WARP_SIZE, 1, 1);
85+
const int64_t num_blocks = nrows;
86+
const dim3 blocks_dim(std::min<int64_t>(ne00, 1024), 1, 1);
7687
const dim3 blocks_num(num_blocks, 1, 1);
7788

7889
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);

tests/test-backend-ops.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,6 +1155,26 @@ struct test_argmax : public test_case {
11551155
return out;
11561156
}
11571157

1158+
void initialize_tensors(ggml_context * ctx) override {
1159+
std::random_device rd;
1160+
std::default_random_engine rng(rd());
1161+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
1162+
if (t->type == GGML_TYPE_F32) {
1163+
// initialize with unique values to avoid ties
1164+
for (int64_t r = 0; r < ggml_nrows(t); r++) {
1165+
std::vector<float> data(t->ne[0]);
1166+
for (int i = 0; i < t->ne[0]; i++) {
1167+
data[i] = i;
1168+
}
1169+
std::shuffle(data.begin(), data.end(), rng);
1170+
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
1171+
}
1172+
} else {
1173+
init_tensor_uniform(t);
1174+
}
1175+
}
1176+
}
1177+
11581178
double max_nmse_err() override {
11591179
return 0.0;
11601180
}
@@ -3441,6 +3461,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
34413461
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
34423462

34433463
test_cases.emplace_back(new test_argmax());
3464+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
3465+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
3466+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
3467+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
3468+
34443469
test_cases.emplace_back(new test_count_equal());
34453470

34463471
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
@@ -3831,6 +3856,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
38313856
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
38323857
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
38333858

3859+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
3860+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
3861+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
3862+
38343863
for (int bs : {1, 512}) {
38353864
for (ggml_type type_a : all_types) {
38363865
for (ggml_type type_b : {GGML_TYPE_F32}) {

0 commit comments

Comments
 (0)