Skip to content

Commit 9c5bdf4

Browse files
committed
kompute: op_mul_mat_[q4_0|q4_1|q8_0] permuted support
Signed-off-by: Sergio Lopez <[email protected]>
1 parent 2ac1d0e commit 9c5bdf4

File tree

3 files changed

+29
-10
lines changed

3 files changed

+29
-10
lines changed

ggml/src/ggml-kompute/ggml-kompute.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,26 +1018,32 @@ static void ggml_vk_mul_mat_impl(
10181018
int32_t ne00, int32_t ne01, int32_t ne02,
10191019
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
10201020
int32_t ne0, int32_t ne1,
1021+
uint32_t nb01, uint32_t nb02, uint32_t nb03,
1022+
uint32_t nb11, uint32_t nb12, uint32_t nb13,
10211023
uint32_t r2, uint32_t r3
10221024
) {
10231025
struct PushConstants {
10241026
uint32_t inAOff, inBOff, outOff;
10251027
int32_t ne00, ne01, ne02;
10261028
int32_t ne10, ne12;
10271029
int32_t ne0, ne1;
1030+
uint32_t nb01, nb02, nb03;
1031+
uint32_t nb11, nb12, nb13;
10281032
uint32_t r2, r3;
10291033
} pushConsts {
10301034
safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
10311035
ne00, ne01, ne02,
10321036
ne10, ne12,
10331037
ne0, ne1,
1038+
nb01, nb02, nb03,
1039+
nb11, nb12, nb13,
10341040
r2, r3
10351041
};
10361042

10371043
auto name = std::string(__func__) + "_" + suffix;
10381044
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
10391045
if (!komputeManager()->hasAlgorithm(name)) {
1040-
const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
1046+
const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8;
10411047
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
10421048
} else {
10431049
s_algo = komputeManager()->getAlgorithm(name);
@@ -1694,19 +1700,22 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
16941700
case GGML_TYPE_Q8_0:
16951701
ggml_vk_mul_mat_q8_0(
16961702
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1697-
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1703+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
1704+
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
16981705
);
16991706
break;
17001707
case GGML_TYPE_Q4_0:
17011708
ggml_vk_mul_mat_q4_0(
17021709
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1703-
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1710+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
1711+
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
17041712
);
17051713
break;
17061714
case GGML_TYPE_Q4_1:
17071715
ggml_vk_mul_mat_q4_1(
17081716
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1709-
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1717+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
1718+
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
17101719
);
17111720
break;
17121721
case GGML_TYPE_Q4_K:

ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@ void main() {
1414
const uint i12 = im%pcs.ne12;
1515
const uint i13 = im/pcs.ne12;
1616

17-
const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
17+
// pointers to src0 rows
18+
uint ax[N_ROWS];
19+
for (int row = 0; row < N_ROWS; ++row) {
20+
const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
21+
22+
ax[row] = offset0 + pcs.inAOff;
23+
}
1824

19-
const uint x = offset0; // Based from inA without base offset
20-
const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
25+
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
2126

2227
float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
2328

@@ -32,8 +37,7 @@ void main() {
3237

3338
for (uint ib = ix; ib < nb; ib += 16) {
3439
for (int row = 0; row < N_ROWS; row++) {
35-
const uint block_index = x + ib + row * nb;
36-
sumf[row] += block_q_n_dot_y(block_index, yb, il);
40+
sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
3741
}
3842

3943
yb += BLOCKS_IN_QUANT * 16;

ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
layout(local_size_x_id = 0) in;
2-
layout(local_size_y = 1) in;
2+
layout(local_size_y = 8) in;
33
layout(local_size_z = 1) in;
44

55
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
@@ -17,6 +17,12 @@ layout (push_constant) uniform parameter {
1717
int ne12;
1818
int ne0;
1919
int ne1;
20+
uint nb01;
21+
uint nb02;
22+
uint nb03;
23+
uint nb11;
24+
uint nb12;
25+
uint nb13;
2026
uint r2;
2127
uint r3;
2228
} pcs;

0 commit comments

Comments
 (0)