Skip to content

Commit aa7f6a8

Browse files
committed
vulkan: copy iq4_nl LUT into shared memory
1 parent 3ee6382 commit aa7f6a8

File tree

6 files changed

+29
-4
lines changed

6 files changed

+29
-4
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
1212

13+
init_iq4nl_shmem();
14+
1315
const uint tid = gl_LocalInvocationID.x % 64;
1416
const uint il = tid/32;
1517
const uint ir = tid%32;

ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ void main() {
1212
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
1313
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
1414

15+
#if defined(DATA_A_IQ4_NL)
16+
init_iq4nl_shmem();
17+
#endif
18+
1519
if (i00 >= p.ne00) {
1620
return;
1721
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
102102
void main() {
103103
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
104104

105+
#if defined(DATA_A_IQ4_NL)
106+
init_iq4nl_shmem();
107+
#endif
108+
105109
// do NUM_ROWS at a time, unless there aren't enough remaining rows
106110
if (first_row + NUM_ROWS <= p.stride_d) {
107111
compute_outputs(first_row, NUM_ROWS);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ shared u16vec2 row_ids[3072];
7575
#endif
7676

7777
void main() {
78+
#if defined(DATA_A_IQ4_NL)
79+
init_iq4nl_shmem();
80+
#endif
81+
7882
#ifdef MUL_MAT_ID
7983
const uint expert_idx = gl_GlobalInvocationID.z;
8084
#else

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,19 @@ struct block_iq4_nl
193193

194194
#define A_TYPE block_iq4_nl
195195

196-
const int8_t kvalues_iq4nl[16] = {
196+
const int8_t kvalues_iq4nl_const[16] = {
197197
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
198198
int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
199199
};
200+
201+
shared FLOAT_TYPE kvalues_iq4nl[16];
202+
203+
void init_iq4nl_shmem()
204+
{
205+
// copy the table into shared memory and sync
206+
if (gl_LocalInvocationIndex.x < 16) {
207+
kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]);
208+
}
209+
barrier();
210+
}
200211
#endif

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,11 @@ void process_shaders() {
331331
shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
332332

333333
if (tname == "f16") {
334-
string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
334+
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
335335
} else {
336-
string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
336+
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
337337
}
338-
string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
338+
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
339339
}
340340
}
341341

0 commit comments

Comments
 (0)