Skip to content

Commit 0e3d85d

Browse files
committed
kompute: op_mul_mat_q6_k permutted support
Signed-off-by: Sergio Lopez <[email protected]>
1 parent f54c96e commit 0e3d85d

File tree

2 files changed

+37
-15
lines changed

2 files changed

+37
-15
lines changed

ggml/src/ggml-kompute/ggml-kompute.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,28 +1126,37 @@ static void ggml_vk_mul_mat_q6_k(
11261126
const std::shared_ptr<kp::Tensor>& inB,
11271127
const std::shared_ptr<kp::Tensor>& out,
11281128
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1129-
int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
1130-
int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
1129+
int32_t ne00, int32_t ne01, int32_t ne02,
1130+
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
1131+
int32_t ne0, int32_t ne1,
1132+
uint32_t nb01, uint32_t nb02, uint32_t nb03,
1133+
uint32_t nb11, uint32_t nb12, uint32_t nb13,
1134+
uint32_t r2, uint32_t r3
11311135
) {
11321136
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
11331137
kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
11341138

11351139
struct PushConstants {
11361140
uint32_t inAOff, inBOff, outOff;
1137-
int32_t ne00, ne10, ne0, ne1, ne01, gqa;
1141+
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
1142+
uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
1143+
uint32_t r2, r3;
11381144
} pushConsts {
11391145
inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
1140-
ne00, ne10, ne0, ne1, ne01, ne12/ne02
1146+
ne00, ne10, ne0, ne1, ne01, ne02, ne12,
1147+
nb01, nb02, nb03, nb11, nb12, nb13,
1148+
r2, r3
11411149
};
11421150

11431151
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
11441152
if (!komputeManager()->hasAlgorithm(__func__)) {
1145-
const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
1146-
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
1153+
const uint32_t local_x = 2;
1154+
const uint32_t local_y = ggml_vk_current_device().subgroupSize;
1155+
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts});
11471156
} else {
11481157
s_algo = komputeManager()->getAlgorithm(__func__);
11491158
s_algo->setTensors({inA, inB, out});
1150-
s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
1159+
s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)});
11511160
s_algo->setPushConstants<PushConstants>({pushConsts});
11521161
s_algo->updateDescriptors(s_kompute_context->pool.get());
11531162
}
@@ -1450,8 +1459,8 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
14501459

14511460
switch (op->src[0]->type) {
14521461
case GGML_TYPE_F32:
1453-
case GGML_TYPE_Q6_K:
14541462
return op->ne[3] == 1;
1463+
case GGML_TYPE_Q6_K:
14551464
case GGML_TYPE_F16:
14561465
case GGML_TYPE_Q8_0:
14571466
case GGML_TYPE_Q4_0:
@@ -1729,7 +1738,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
17291738
case GGML_TYPE_Q6_K:
17301739
ggml_vk_mul_mat_q6_k(
17311740
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1732-
ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
1741+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
1742+
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
17331743
);
17341744
break;
17351745
default: {

ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,16 @@ layout (push_constant) uniform parameter {
2121
int ne0;
2222
int ne1;
2323
int ne01;
24-
int gqa;
24+
int ne02;
25+
int ne12;
26+
uint nb01;
27+
uint nb02;
28+
uint nb03;
29+
uint nb11;
30+
uint nb12;
31+
uint nb13;
32+
uint r2;
33+
uint r3;
2534
} pcs;
2635

2736
void main() {
@@ -34,12 +43,15 @@ void main() {
3443

3544
const uint r0 = gl_WorkGroupID.x;
3645
const uint r1 = gl_WorkGroupID.y;
37-
const uint r2 = gl_WorkGroupID.z;
46+
const uint im = gl_WorkGroupID.z;
3847

3948
const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
40-
const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
41-
const uint x = row * nb + offset0; // Based from inA without base offset
42-
const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
49+
50+
const uint i12 = im%pcs.ne12;
51+
const uint i13 = im/pcs.ne12;
52+
53+
const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
54+
const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
4355

4456
float sumf = 0;
4557

@@ -89,6 +101,6 @@ void main() {
89101

90102
const float tot = subgroupAdd(sumf);
91103
if (subgroupElect()) {
92-
out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
104+
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
93105
}
94106
}

0 commit comments

Comments
 (0)