Skip to content

Commit 2ac1d0e

Browse files
committed
kompute: op_mul_mat_q4_k permutted support
Signed-off-by: Sergio Lopez <[email protected]>
1 parent 1b8afa8 commit 2ac1d0e

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

ggml/src/ggml-kompute/ggml-kompute.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,19 +1079,26 @@ static void ggml_vk_mul_mat_q4_k(
10791079
const std::shared_ptr<kp::Tensor>& inB,
10801080
const std::shared_ptr<kp::Tensor>& out,
10811081
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1082-
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
1083-
int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
1084-
int32_t ne1, int32_t r2, int32_t r3
1082+
int32_t ne00, int32_t ne01, int32_t ne02,
1083+
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
1084+
int32_t ne0, int32_t ne1,
1085+
uint32_t nb01, uint32_t nb02, uint32_t nb03,
1086+
uint32_t nb11, uint32_t nb12, uint32_t nb13,
1087+
uint32_t r2, uint32_t r3
10851088
) {
10861089
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
10871090
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
10881091

10891092
struct PushConstants {
10901093
uint32_t inAOff, inBOff, outOff;
1091-
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
1094+
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
1095+
uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
1096+
uint32_t r2, r3;
10921097
} pushConsts {
1093-
0, 0, 0,
1094-
ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
1098+
inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
1099+
ne00, ne10, ne0, ne1, ne01, ne02, ne12,
1100+
nb01, nb02, nb03, nb11, nb12, nb13,
1101+
r2, r3
10951102
};
10961103

10971104
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
@@ -1705,7 +1712,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
17051712
case GGML_TYPE_Q4_K:
17061713
ggml_vk_mul_mat_q4_k(
17071714
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1708-
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
1715+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
1716+
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
17091717
);
17101718
break;
17111719
case GGML_TYPE_Q6_K:

ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,14 @@ layout (push_constant) uniform parameter {
2424
int ne01;
2525
int ne02;
2626
int ne12;
27-
int r2;
28-
int r3;
27+
uint nb01;
28+
uint nb02;
29+
uint nb03;
30+
uint nb11;
31+
uint nb12;
32+
uint nb13;
33+
uint r2;
34+
uint r3;
2935
} pcs;
3036

3137
void main() {
@@ -50,10 +56,11 @@ void main() {
5056
const uint i12 = im%pcs.ne12;
5157
const uint i13 = im/pcs.ne12;
5258

53-
const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
59+
const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
60+
const uint offset1 = r1*pcs.nb11 + (i12 )*pcs.nb12 + (i13 )*pcs.nb13;
5461

55-
const uint xblk = ib_row + offset0 + pcs.inAOff;
56-
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
62+
const uint xblk = offset0 + pcs.inAOff;
63+
const uint y = (offset1 / 4) + pcs.inBOff;
5764

5865
float yl[16];
5966
float yh[16];
@@ -74,7 +81,7 @@ void main() {
7481
}
7582

7683
for (int row = 0; row < N_DST; row++) {
77-
uint row_idx = row * nb;
84+
uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK);
7885

7986
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
8087
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);

0 commit comments

Comments
 (0)