Commit 301332a

vulkan: enable the use of simpler matmul shaders

Import simpler matmul shaders from the kompute backend and use them on GPUs known to not be able to use the regular ones.

Signed-off-by: Sergio Lopez <[email protected]>

1 parent 7fee288 commit 301332a

12 files changed: +941 -8 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 234 additions & 8 deletions
Large diffs are not rendered by default.
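
The C++ diff itself is not rendered above. As a loose, hypothetical sketch of the kind of selection logic the commit message describes (every name and the capability check below are assumptions for illustration, not code from this commit):

// Hypothetical sketch only, not the actual diff: fall back to the simpler
// kompute-derived pipelines on devices known to mishandle the regular ones.
#include <cstdint>
#include <string>

struct device_info {              // assumed fields, for illustration
    uint32_t vendor_id;
    bool     subgroup_arithmetic; // the simpler shaders require GL_KHR_shader_subgroup_arithmetic
};

static bool use_simpler_matmul(const device_info & dev) {
    const bool regular_shaders_known_broken = false; // e.g. a per-device quirk check
    return regular_shaders_known_broken && dev.subgroup_arithmetic;
}

static std::string pick_mul_mv_pipeline(const device_info & dev, bool a_is_f16) {
    if (use_simpler_matmul(dev)) {
        return a_is_f16 ? "simpler_mul_mat_f16" : "simpler_mul_mat_f32"; // hypothetical names
    }
    return a_is_f16 ? "mul_mat_vec_f16" : "mul_mat_vec_f32";             // hypothetical names
}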
Lines changed: 112 additions & 0 deletions (new file with the shared definitions; the shaders below include it as "simpler_common.comp")

#extension GL_EXT_shader_16bit_storage: require
#extension GL_EXT_shader_8bit_storage: require
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
#extension GL_EXT_control_flow_attributes: enable
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable

#define QK4_0 32
#define QK4_1 32

#define GELU_COEF_A 0.044715
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
#define TWOPI_F 6.283185307179586f

#define QK_K 256
#define K_SCALE_SIZE 12

#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx])
#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx)

#define sizeof_block_q4_0 0x12
struct block_q4_0 {
    float16_t d;
    uint8_t qs[QK4_0 / 2];
};
mat4 dequantize_q4_0(const block_q4_0 xb, uint il) {
    const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
    const float d2 = d1 / 256.f;
    const float md = -8.f * xb.d;
    const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
    const uint16_t mask1 = mask0 << 8;

    mat4 reg;
    for (int i = 0; i < 8; i++) {
        uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
        reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md;
        reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md;
    }
    return reg;
}

#define sizeof_block_q4_1 0x14
struct block_q4_1 {
    float16_t d;
    float16_t m;
    uint8_t qs[QK4_1 / 2];
};
mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
    const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
    const float d2 = d1 / 256.f;
    const float m = xb.m;
    const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
    const uint16_t mask1 = mask0 << 8;

    mat4 reg;
    for (int i = 0; i < 8; i++) {
        uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
        reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m;
        reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m;
    }
    return reg;
}

#define sizeof_block_q4_k 144
struct block_q4_k {
    float16_t d;
    float16_t dmin;
    uint8_t scales[K_SCALE_SIZE];
    uint8_t qs[QK_K/2];
};

#define sizeof_block_q6_k 210
struct block_q6_k {
    uint8_t ql[QK_K/2];      // quants, lower 4 bits
    uint8_t qh[QK_K/4];      // quants, upper 2 bits
    int8_t scales[QK_K/16];  // scales, quantized with 8 bits
    float16_t d;             // super-block scale
};
mat4 dequantize_q6_k(const block_q6_k xb, uint il) {
    const float16_t d_all = xb.d;

    const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
    const uint qhIndex = 32*(il/8) + 16*(il&1);
    float16_t sc = xb.scales[(il%2) + 2 * ((il/2))];
    il = (il/2) & 3;

    const uint16_t kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3);
    const uint16_t kmask2 = il>1 ? uint8_t(0xF0) : uint8_t(0x0F);
    const float16_t coef = il>1 ? float16_t(1.f/16.f) : float16_t(1.f);
    const float16_t ml = float16_t(d_all * sc * 32.f);
    const float16_t dl = float16_t(d_all * sc * coef);
    mat4 reg;
    for (int i = 0; i < 16; ++i) {
        const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2))
                                        : ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4));
        reg[i/4][i%4] = dl * q - ml;
    }
    return reg;
}


#define QK8_0 32
// struct block_q8_0 {
//     float16_t d;      // delta
//     int8_t qs[QK8_0]; // quants
// };
#define sizeof_block_q8_0 34
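
Since the quantized blocks are read from a raw byte buffer through the u8BufTo* helpers, the q4_0 mapping above is easy to check on the host: each 4-bit quant q in [0, 15] decodes to (q - 8) * d. A minimal C++ mirror, a sketch only (the fp16 scale is replaced by a plain float here, and the struct is an illustrative host-side stand-in):

// Host-side sanity check for the q4_0 value mapping; not part of the commit.
#include <cstdint>
#include <cstdio>

constexpr int QK4_0_HOST = 32;

struct block_q4_0_host {
    float   d;                   // block scale (stored as float16_t on the GPU)
    uint8_t qs[QK4_0_HOST / 2];  // 32 quants, two 4-bit values per byte
};

int main() {
    block_q4_0_host b = {0.5f, {}};
    for (int i = 0; i < QK4_0_HOST / 2; ++i) {
        b.qs[i] = 0x9A;  // low nibble = 10, high nibble = 9
    }
    const int lo =  b.qs[0]       & 0x0F;
    const int hi = (b.qs[0] >> 4) & 0x0F;
    // Same result as the shader's d1 * (b & mask) + md with md = -8 * d:
    printf("lo -> %.2f, hi -> %.2f\n",
           b.d * lo - 8.f * b.d,   // (10 - 8) * 0.5 = 1.00
           b.d * hi - 8.f * b.d);  // ( 9 - 8) * 0.5 = 0.50
    return 0;
}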
Lines changed: 69 additions & 0 deletions (new file: the simpler f16 x f32 matrix multiplication shader)

#version 450

#include "simpler_common.comp"

#extension GL_KHR_shader_subgroup_arithmetic : require

layout(local_size_x_id = 0) in;

layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };

layout (push_constant) uniform parameter {
    uint inAOff;
    uint inBOff;
    uint outOff;
    int ne00;
    int ne01;
    int ne02;
    uint nb00;
    uint nb01;
    uint nb02;
    uint nb03;
    int ne10;
    int ne11;
    int ne12;
    uint nb10;
    uint nb11;
    uint nb12;
    uint nb13;
    int ne0;
    int ne1;
    uint r2;
    uint r3;
} pcs;

#define N_F16_F32 4

void main() {
    const uint r0 = gl_WorkGroupID.x;
    const uint rb = gl_WorkGroupID.y*N_F16_F32;
    const uint im = gl_WorkGroupID.z;

    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03;

    const uint x = offset0 / 2 + pcs.inAOff; // Based from inA

    for (uint row = 0; row < N_F16_F32; ++row) {
        uint r1 = rb + row;
        if (r1 >= pcs.ne11) {
            break;
        }

        const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;

        float sumf = 0;
        for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
            sumf += float(inA[x+i]) * float(inB[y+i]);
        }

        const float all_sum = subgroupAdd(sumf);
        if (subgroupElect()) {
            out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum;
        }
    }
}
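
The workgroup geometry implied by main() above: gl_WorkGroupID.x walks the ne01 rows of A, .y covers the ne11 rows of B in groups of N_F16_F32, and .z flattens the ne12*ne13 batch dimensions. A hedged host-side sketch of the matching dispatch size (the helper name is made up):

// Hypothetical helper deriving the vkCmdDispatch dimensions this shader
// expects; a sketch based on how main() decodes gl_WorkGroupID.
#include <cstdint>

struct dispatch_dims { uint32_t x, y, z; };

dispatch_dims simpler_mul_mv_f16_f32_dims(uint32_t ne01, uint32_t ne11,
                                          uint32_t ne12, uint32_t ne13) {
    const uint32_t N_F16_F32 = 4;  // rows of B handled per workgroup
    return {
        ne01,                                // .x -> r0, one row of A each
        (ne11 + N_F16_F32 - 1) / N_F16_F32,  // .y -> rb, up to 4 rows of B each
        ne12 * ne13,                         // .z -> im, flattened batch
    };
}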
Lines changed: 51 additions & 0 deletions (new file: the simpler f32 x f32 matrix multiplication shader)

#version 450

#include "simpler_common.comp"

#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable

// device subgroup size
layout (local_size_x_id = 0) in;

layout(binding = 0) readonly buffer tensorInA { float inA[]; };
layout(binding = 1) readonly buffer tensorInB { float inB[]; };
layout(binding = 2) writeonly buffer tensorOut { float out_[]; };

layout(push_constant) uniform parameter {
    uint inAOff;
    uint inBOff;
    uint outOff;
    int ne00;
    int ne01;
    int ne02;
    int ne11;
    int ne12;
    uint nb01;
    uint nb02;
    uint nb11;
    uint nb12;
    uint nb1;
    uint nb2;
} pcs;


void main() {
    uvec3 gid = gl_WorkGroupID;

    uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z;
    uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z;

    const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA
    const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB
    float sum = 0.0f;
    for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
        sum += float(inA[x+i]) * float(inB[y+i]);
    }

    const float all_sum = subgroupAdd(sum);
    if (subgroupElect()) {
        out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum;
    }
}
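
Both float shaders follow the same reduction pattern: the workgroup is sized to the device subgroup (local_size_x_id = 0), every invocation strides over the ne00 inner dimension accumulating partial products, subgroupAdd() combines the partial sums across the subgroup, and subgroupElect() picks a single invocation to write the finished dot product. In this f32 variant each workgroup therefore produces exactly one output element, and the bc_ab/bc_ba indices broadcast whichever operand has the smaller batch count (ne02 vs. ne12) across the other.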
Lines changed: 33 additions & 0 deletions (new file: the simpler q4_0 matrix-vector shader)

#version 450

#include "simpler_common.comp"

#define BLOCKS_IN_QUANT QK4_0
#define SIZE_OF_BLOCK sizeof_block_q4_0
#define N_ROWS 4

#include "simpler_mul_mv_q_n_pre.comp"

// The q4_0 version of this function
float block_q_n_dot_y(uint block_index, uint yb, uint il) {
    vec2 acc = vec2(0.0, 0.0);
    const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
    float d = float(u8BufToFloat16(inA, index));
    float sumy = 0.0f;
    for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
        const uint16_t b = u8BufToU16(inA, index + 2 + il + i);

        const float yl0 = inB[yb + i];
        const float yl1 = inB[yb + i + 1];
        const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
        const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];

        sumy += yl0 + yl1 + yl8 + yl9;

        acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
        acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
    }
    return d * (sumy * -8.f + acc[0] + acc[1]);
}

#include "simpler_mul_mv_q_n.comp"
Lines changed: 35 additions & 0 deletions (new file: the simpler q4_1 matrix-vector shader)

#version 450

#include "simpler_common.comp"

#define BLOCKS_IN_QUANT QK4_1
#define SIZE_OF_BLOCK sizeof_block_q4_1
#define N_ROWS 4

#include "simpler_mul_mv_q_n_pre.comp"

// The q4_1 version of this function
float block_q_n_dot_y(uint block_index, uint yb, uint il) {
    vec2 acc = vec2(0.0, 0.0);
    const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
    float d = float(u8BufToFloat16(inA, index));
    float m = float(u8BufToFloat16(inA, index+2));

    float sumy = 0.0f;
    for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
        const uint16_t b = u8BufToU16(inA, index + 4 + il + i);

        const float yl0 = inB[yb + i];
        const float yl1 = inB[yb + i + 1];
        const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
        const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];

        sumy += yl0 + yl1 + yl8 + yl9;

        acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
        acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
    }
    return d * (acc[0] + acc[1]) + sumy * m;
}

#include "simpler_mul_mv_q_n.comp"
