Skip to content

Commit ec6b2e3

Browse files
author
Gitty Burstein
committed
Fix: move SparseK op_params indices to 7..10 to avoid overlap with precision params
Co-authored-by: Yael Shuker <[email protected]>
Co-authored-by: Gitty Burstein <[email protected]>
1 parent 4815f37 commit ec6b2e3

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -219,7 +219,7 @@
219 219
#define GGML_MAX_PARAMS 2048
220 220
#define GGML_MAX_SRC 10
221 221
#define GGML_MAX_N_THREADS 512
222-
#define GGML_MAX_OP_PARAMS 128
222+
#define GGML_MAX_OP_PARAMS 64
223 223

224 224
#ifndef GGML_MAX_NAME
225 225
# define GGML_MAX_NAME 64

ggml/src/ggml-cpu/ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8004,10 +8004,10 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8004 8004
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8005 8005

8006 8006
// -------- SparseK op_params --------
8007-
const bool use_sparsek = ggml_get_op_params_i32(dst, 3) != 0;
8008-
const int32_t k_top = ggml_get_op_params_i32(dst, 4);
8009-
const int32_t win_local = ggml_get_op_params_i32(dst, 5);
8010-
const int32_t stride_glb = ggml_get_op_params_i32(dst, 6);
8007+
const bool use_sparsek = ggml_get_op_params_i32(dst, 7) != 0;
8008+
const int32_t k_top = ggml_get_op_params_i32(dst, 8);
8009+
const int32_t win_local = ggml_get_op_params_i32(dst, 9);
8010+
const int32_t stride_glb = ggml_get_op_params_i32(dst, 10);
8011 8011
// ----------------------------------------------------------------------------
8012 8012

8013 8013
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;

ggml/src/ggml.c

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5214,10 +5214,10 @@ struct ggml_tensor * ggml_flash_attn_back(
5214 5214
return result;
5215 5215
}
5216 5216

5217-
#define GGML_FA_EXT_PARAM_SPARSEK_FLAG 3
5218-
#define GGML_FA_EXT_PARAM_SPARSEK_KTOP 4
5219-
#define GGML_FA_EXT_PARAM_SPARSEK_WIN 5
5220-
#define GGML_FA_EXT_PARAM_SPARSEK_STRIDE 6
5217+
#define GGML_FA_EXT_PARAM_SPARSEK_FLAG 7
5218+
#define GGML_FA_EXT_PARAM_SPARSEK_KTOP 8
5219+
#define GGML_FA_EXT_PARAM_SPARSEK_WIN 9
5220+
#define GGML_FA_EXT_PARAM_SPARSEK_STRIDE 10
5221 5221

5222 5222
void ggml_flash_attn_ext_set_sparsek(struct ggml_tensor * a,
5223 5223
bool use_sparsek,

0 commit comments

Comments
 (0)