Skip to content

Commit 924bccc

Browse files
committed
vulkan: support copy from f32 to q4_0/q4_1/q5_0/q5_1/q8_0/iq4_nl
Shaders are based on cpy.cu.
1 parent 1204f97 commit 924bccc

File tree

4 files changed

+287
-5
lines changed

4 files changed

+287
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 36 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -228,6 +228,7 @@ struct vk_device_struct {
228228
vk_pipeline pipeline_repeat_f32;
229229
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
230230
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
231+
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
231232
vk_pipeline pipeline_norm_f32;
232233
vk_pipeline pipeline_group_norm_f32;
233234
vk_pipeline pipeline_rms_norm_f32;
@@ -1965,6 +1966,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
19651966
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
19661967
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
19671968

1969+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
1970+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
1971+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
1972+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
1973+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
1974+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
1975+
19681976
ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
19691977
ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
19701978
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@@ -3689,6 +3697,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
36893697
return ctx->device->pipeline_cpy_f16_f16;
36903698
}
36913699
}
3700+
if (src->type == GGML_TYPE_F32) {
3701+
switch (to) {
3702+
case GGML_TYPE_Q4_0:
3703+
case GGML_TYPE_Q4_1:
3704+
case GGML_TYPE_Q5_0:
3705+
case GGML_TYPE_Q5_1:
3706+
case GGML_TYPE_Q8_0:
3707+
case GGML_TYPE_IQ4_NL:
3708+
return ctx->device->pipeline_cpy_f32_quant[to];
3709+
default:
3710+
break;
3711+
}
3712+
}
36923713

36933714
std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
36943715
GGML_ABORT("fatal error");
@@ -7905,11 +7926,21 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
79057926
{
79067927
ggml_type src0_type = op->src[0]->type;
79077928
ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
7908-
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
7909-
return true;
7910-
}
7911-
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
7912-
return true;
7929+
7930+
if (src0_type == GGML_TYPE_F32) {
7931+
switch (src1_type) {
7932+
case GGML_TYPE_F32:
7933+
case GGML_TYPE_F16:
7934+
case GGML_TYPE_Q4_0:
7935+
case GGML_TYPE_Q4_1:
7936+
case GGML_TYPE_Q5_0:
7937+
case GGML_TYPE_Q5_1:
7938+
case GGML_TYPE_Q8_0:
7939+
case GGML_TYPE_IQ4_NL:
7940+
return true;
7941+
default:
7942+
break;
7943+
}
79137944
}
79147945
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
79157946
return true;
ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 237 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,237 @@
1+
#version 450
2+
3+
#include "types.comp"
4+
#include "generic_unary_head.comp"
5+
6+
// Workgroup size selection. Only local invocation 0 performs the quantization
// in main(); the IQ4_NL variant launches extra invocations solely for table setup.
#if defined(DATA_A_IQ4_NL)
7+
// 16 invocations needed for init_iq4nl_shmem
8+
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
9+
#else
10+
// All other quant types: a single invocation per workgroup.
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
11+
#endif
12+
13+
// Binding 0: source values, read as 32-bit floats.
layout (binding = 0) readonly buffer S {float data_s[];};
14+
// Binding 1: destination array of quantized blocks; A_TYPE is the block struct
// selected by the DATA_A_* define (presumably declared in types.comp - confirm).
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
15+
16+
#if defined(DATA_A_Q4_0)
17+
// Quantize one Q4_0 block.
//   dst_idx: index of the destination block in data_q.
//   src_idx: index of the first of QUANT_K_Q4_0 consecutive source floats in data_s.
// Writes the block scale `d` (as float16) and QUANT_K_Q4_0/2 packed bytes to qs.
void quantize(uint dst_idx, uint src_idx)
18+
{
19+
float amax = 0.0;
20+
float vmax = 0.0;
21+
22+
// Find the element with the largest magnitude, keeping its signed value in vmax.
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) {
23+
const float v = data_s[src_idx + j];
24+
if (amax < abs(v)) {
25+
amax = abs(v);
26+
vmax = v;
27+
}
28+
}
29+
30+
// Scale chosen so that vmax * id == -8, i.e. the largest-magnitude value lands on nibble 0
// after the +8 offset applied below.
const float d = vmax / -8;
31+
// Inverse scale; 0 when the whole block is zero so no division by zero occurs.
const float id = (d != 0.0) ? 1.0/d : 0.0;
32+
33+
data_q[dst_idx].d = float16_t(d);
34+
35+
// Byte j packs element j (low nibble) with element j + QUANT_K_Q4_0/2 (high nibble).
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) {
36+
const float x0 = data_s[src_idx + 0 + j]*id;
37+
const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id;
38+
39+
// Add 8.5 (offset 8 plus 0.5) before the int conversion, then cap at 15 so the result fits a nibble.
const uint xi0 = min(15, int(x0 + 8.5));
40+
const uint xi1 = min(15, int(x1 + 8.5));
41+
42+
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
43+
}
44+
}
45+
#endif
46+
47+
#if defined(DATA_A_Q4_1)
48+
// Quantize one Q4_1 block.
//   dst_idx: index of the destination block in data_q.
//   src_idx: index of the first of QUANT_K_Q4_1 consecutive source floats in data_s.
// Writes the scale `d` and the block minimum `m` (both as float16) plus
// QUANT_K_Q4_1/2 packed bytes to qs.
void quantize(uint dst_idx, uint src_idx)
49+
{
50+
// NOTE(review): 1.0/0.0 is presumably meant to yield +infinity as the initial minimum - confirm
// the target compilers evaluate it that way.
float vmin = 1.0/0.0;
51+
float vmax = -vmin;
52+
53+
// Track the smallest and largest value in the block.
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) {
54+
const float v = data_s[src_idx + j];
55+
56+
if (v < vmin) vmin = v;
57+
if (v > vmax) vmax = v;
58+
}
59+
60+
// Spread the [vmin, vmax] range over the 15 steps ((1 << 4) - 1) of a 4-bit code.
const float d = (vmax - vmin) / ((1 << 4) - 1);
61+
// Inverse scale; 0 when all values are equal so no division by zero occurs.
const float id = (d != 0.0) ? 1.0/d : 0.0;
62+
63+
data_q[dst_idx].d = float16_t(d);
64+
data_q[dst_idx].m = float16_t(vmin);
65+
66+
// Byte j packs element j (low nibble) with element j + QUANT_K_Q4_1/2 (high nibble).
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) {
67+
const float x0 = (data_s[src_idx + 0 + j] - vmin)*id;
68+
const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id;
69+
70+
// Add 0.5 before the int conversion, then cap at 15 so the result fits a nibble.
const uint xi0 = min(15, int(x0 + 0.5));
71+
const uint xi1 = min(15, int(x1 + 0.5));
72+
73+
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
74+
}
75+
}
76+
#endif
77+
78+
#if defined(DATA_A_Q5_0)
79+
// Quantize one Q5_0 block.
//   dst_idx: index of the destination block in data_q.
//   src_idx: index of the first of QUANT_K_Q5_0 consecutive source floats in data_s.
// Writes the scale `d` (float16), the low 4 bits of every 5-bit code to qs, and the
// fifth bits to qh as two 16-bit halves.
void quantize(uint dst_idx, uint src_idx)
80+
{
81+
float amax = 0.0;
82+
float vmax = 0.0;
83+
84+
// Find the element with the largest magnitude, keeping its signed value in vmax.
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) {
85+
const float v = data_s[src_idx + j];
86+
if (amax < abs(v)) {
87+
amax = abs(v);
88+
vmax = v;
89+
}
90+
}
91+
92+
// Scale chosen so that vmax * id == -16, i.e. the largest-magnitude value lands on code 0
// after the +16 offset applied below.
const float d = vmax / -16;
93+
// Inverse scale; 0 when the whole block is zero so no division by zero occurs.
const float id = (d != 0.0) ? 1.0/d : 0.0;
94+
95+
data_q[dst_idx].d = float16_t(d);
96+
97+
// Accumulates bit 4 of every code: bit j for element j, bit j + QUANT_K_Q5_0/2 for its partner.
uint32_t qh = 0;
98+
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) {
99+
const float x0 = data_s[src_idx + 0 + j]*id;
100+
const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id;
101+
102+
// Add 16.5 (offset 16 plus 0.5) before the int conversion, then cap at 31 (5-bit maximum).
const uint xi0 = min(31, int(x0 + 16.5));
103+
const uint xi1 = min(31, int(x1 + 16.5));
104+
105+
// Low 4 bits of both codes share one byte; the fifth bits go to qh.
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
106+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
107+
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2);
108+
}
109+
// The 32-bit mask is stored as two 16-bit words: low half in qh[0], high half in qh[1].
data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF);
110+
data_q[dst_idx].qh[1] = uint16_t(qh >> 16);
111+
}
112+
#endif
113+
114+
#if defined(DATA_A_Q5_1)
115+
// Quantize one Q5_1 block.
//   dst_idx: index of the destination block in data_q.
//   src_idx: index of the first of QUANT_K_Q5_1 consecutive source floats in data_s.
// Writes the scale `d` and block minimum `m` (both float16), the low 4 bits of every
// 5-bit code to qs, and the fifth bits to qh as one 32-bit word.
// NOTE(review): the locals `min`/`max` share their names with the GLSL built-in functions;
// neither built-in is called inside this function.
void quantize(uint dst_idx, uint src_idx)
116+
{
117+
// Seed both extremes with the first element; the scan below starts at j = 1.
float min = data_s[src_idx + 0];
118+
float max = min;
119+
120+
[[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) {
121+
const float v = data_s[src_idx + j];
122+
min = v < min ? v : min;
123+
max = v > max ? v : max;
124+
}
125+
126+
// Spread the [min, max] range over the 31 steps of a 5-bit code.
const float d = (max - min) / 31;
127+
// Inverse scale; 0 when all values are equal so no division by zero occurs.
const float id = (d != 0) ? 1.0/d : 0.0;
128+
129+
data_q[dst_idx].d = float16_t(d);
130+
data_q[dst_idx].m = float16_t(min);
131+
132+
// Accumulates bit 4 of every code: bit j for element j, bit j + QUANT_K_Q5_1/2 for its partner.
uint32_t qh = 0;
133+
[[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) {
134+
const float x0 = (data_s[src_idx + 0 + j] - min)*id;
135+
const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id;
136+
137+
// No explicit upper clamp here (unlike the Q4_1/Q5_0 variants); only bits 0-4 of the
// converted value are used by the masks below.
const uint xi0 = uint(x0 + 0.5);
138+
const uint xi1 = uint(x1 + 0.5);
139+
140+
// Low 4 bits of both codes share one byte; the fifth bits go to qh.
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
141+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
142+
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2);
143+
}
144+
data_q[dst_idx].qh = qh;
145+
}
146+
#endif
147+
148+
#if defined(DATA_A_Q8_0)
149+
// Quantize one Q8_0 block.
//   dst_idx: index of the destination block in data_q.
//   src_idx: index of the first of QUANT_K_Q8_0 consecutive source floats in data_s.
// Writes the scale `d` (float16) and one signed 8-bit code per element to qs.
void quantize(uint dst_idx, uint src_idx)
150+
{
151+
float amax = 0.0; // absolute max
152+
153+
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) {
154+
const float v = data_s[src_idx + j];
155+
amax = max(amax, abs(v));
156+
}
157+
158+
// Scale so the largest magnitude maps to 127 ((1 << 7) - 1).
const float d = amax / ((1 << 7) - 1);
159+
// Inverse scale; 0 when the whole block is zero so no division by zero occurs.
const float id = (d != 0.0) ? 1.0/d : 0.0;
160+
161+
data_q[dst_idx].d = float16_t(d);
162+
163+
// Unlike the 4/5-bit variants, elements are stored in order, one per byte, using round().
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) {
164+
const float x0 = data_s[src_idx + j]*id;
165+
166+
data_q[dst_idx].qs[j] = int8_t(round(x0));
167+
}
168+
}
169+
#endif
170+
171+
#if defined(DATA_A_IQ4_NL)
172+
// Return the index (0..15) of the kvalues_iq4nl entry closest to x.
// Values at or below entry 0 map to 0, values at or above entry 15 map to 15.
// NOTE(review): the bisection assumes kvalues_iq4nl is sorted ascending - confirm where the table is defined.
uint best_index(float x) {
173+
if (x <= kvalues_iq4nl[0]) return 0;
174+
if (x >= kvalues_iq4nl[15]) return 15;
175+
int ml = 0, mu = 15;
176+
// Bisect until ml and mu are adjacent indices bracketing x.
while (mu-ml > 1) {
177+
int mav = (ml+mu)/2;
178+
if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav;
179+
}
180+
// Pick whichever neighbour is nearer; an exact tie resolves to the upper index mu.
return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu;
181+
}
182+
183+
// Quantize one IQ4_NL block.
//   dst_idx: index of the destination block in data_q.
//   src_idx: index of the first of QUANT_K_IQ4_NL consecutive source floats in data_s.
// Writes QUANT_K_IQ4_NL/2 packed bytes of table indices to qs and a refined scale `d` (float16).
void quantize(uint dst_idx, uint src_idx)
184+
{
185+
float amax = 0.0;
186+
float vmax = 0.0;
187+
188+
// Find the element with the largest magnitude, keeping its signed value in vmax.
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) {
189+
const float v = data_s[src_idx + j];
190+
if (amax < abs(v)) {
191+
amax = abs(v);
192+
vmax = v;
193+
}
194+
}
195+
196+
// Initial scale: vmax * id equals kvalues_iq4nl[0], so the largest-magnitude value maps onto table entry 0.
float d = vmax / kvalues_iq4nl[0];
197+
// Inverse scale; 0 when the whole block is zero so no division by zero occurs.
const float id = (d != 0.0) ? 1.0/d : 0.0;
198+
199+
// Running sums for the scale refinement: sumqx = sum(w*q*x), sumq2 = sum(w*q*q), with weight w = x*x.
float sumqx = 0, sumq2 = 0;
200+
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) {
201+
const float x0 = data_s[src_idx + 0 + j]*id;
202+
const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id;
203+
const uint xi0 = best_index(x0);
204+
const uint xi1 = best_index(x1);
205+
// Byte j packs the index for element j (low nibble) with that of element j + QUANT_K_IQ4_NL/2 (high nibble).
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
206+
const float v0 = kvalues_iq4nl[xi0];
207+
const float v1 = kvalues_iq4nl[xi1];
208+
// Weights use the unscaled source values.
const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j];
209+
const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
210+
sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
211+
sumq2 += w0*v0*v0 + w1*v1*v1;
212+
}
213+
214+
// Stored scale is sumqx/sumq2 (the d minimizing sum(w*(x - d*q)^2)); falls back to the initial d when sumq2 is not positive.
data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d);
215+
216+
}
217+
#endif
218+
219+
// Shader entry point: each workgroup quantizes the block of QUANT_K source elements
// that starts at the flat index derived from its workgroup ID.
void main() {
220+
#if defined(DATA_A_IQ4_NL)
221+
// Every invocation calls the helper before any of them exits; presumably it fills the
// shared kvalues_iq4nl table cooperatively - confirm in its definition.
init_iq4nl_shmem();
222+
// Only local invocation 0 goes on to quantize.
if (gl_LocalInvocationIndex.x != 0) {
223+
return;
224+
}
225+
#endif
226+
227+
// Flat start index: x advances by one block (QUANT_K), y by 512, z by 262144 (= 512 * 512).
// NOTE(review): presumably mirrors the host-side dispatch dimensions - confirm against the caller.
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
228+
229+
// Workgroups whose start index is at or beyond p.ne have nothing to do.
if (idx >= p.ne) {
230+
return;
231+
}
232+
233+
// Destination is addressed per block, source per element (plus the offset from get_aoffset()).
uint dst_idx = dst_idx_quant(idx, QUANT_K);
234+
uint src_idx = get_aoffset() + src0_idx(idx);
235+
236+
quantize(dst_idx, src_idx);
237+
}

ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,3 +54,13 @@ uint dst_idx(uint idx) {
5454
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
5555
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
5656
}
57+
58+
// Map a flat element index to an index into a destination made of quantized blocks.
//   idx: flat element index, decomposed into (i13, i12, i11, i10) using the p.ne1x extents.
//   qk:  number of elements per quantized block.
// Returns i13*nb13 + i12*nb12 + i11*nb11 + (i10/qk)*nb10; the innermost coordinate is
// divided by qk so that it counts blocks rather than elements.
// NOTE(review): presumably the p.nb1x strides are expressed in blocks for quantized destinations - confirm on the host side.
uint dst_idx_quant(uint idx, uint qk) {
59+
// fastdiv uses the precomputed (mp, L) pairs from the push constants; presumably this is
// division by ne12*ne11*ne10, ne11*ne10 and ne10 respectively - confirm in fastdiv's definition.
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
60+
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
61+
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
62+
const uint i12_offset = i12*p.ne11*p.ne10;
63+
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
64+
// Remainder after removing the three outer coordinates.
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
65+
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10;
66+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -419,6 +419,10 @@ void process_shaders() {
419419
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
420420
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
421421

422+
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
423+
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
424+
}
425+
422426
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
423427
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
424428

0 commit comments

Comments
 (0)