Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 95 additions & 55 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

Expand Down Expand Up @@ -243,74 +250,100 @@ void main() {
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;

const uint ib = idx / 16;
const uint iqs = idx & 0xF;

const float d = float(data_a[ib].d);
const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;

buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;

const uint ib = idx / 4;
const uint iqs = idx & 0x03;

const float d = float(data_a_packed16[ib].d);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;

buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;

const uint ib = idx / 16;
const uint iqs = idx & 0xF;

const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;

buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;

const uint ib = idx / 4;
const uint iqs = idx & 0x03;

const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;

buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;

const uint ib = idx / 16;
const uint iqs = idx & 0xF;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;

const float d = float(data_a[ib].d);
const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
const float d = float(data_a_packed16[ib].d);
const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);

const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;

buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;

const uint ib = idx / 16;
const uint iqs = idx & 0xF;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;

const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint uint_qh = data_a[ib].qh;
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint uint_qh = data_a_packed16[ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);

const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;

buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q8_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

const uint ib = idx / 16;
const uint iqs = (idx & 0xF) * 2;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;

const float d = float(data_a[ib].d);
const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
const float d = float(data_a_packed16[ib].d);
const uint v0 = uint(data_a_packed16[ib].qs[2*iqs]);
const uint v1 = uint(data_a_packed16[ib].qs[2*iqs + 1]);
const vec4 v = vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8)) * d;

buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
Expand Down Expand Up @@ -623,17 +656,24 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;

const uint ib = idx / 16;
const uint iqs = idx & 0xF;

const float d = float(data_a[ib].d);
const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;

buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;

const uint ib = idx / 4;
const uint iqs = idx & 0x03;

const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
const u8vec4 v0 = unpack8(vui & 0x0F0F0F0F);
const u8vec4 v1 = unpack8((vui >> 4) & 0x0F0F0F0F);

buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[v0.x]) * d;
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[v0.y]) * d;
buf_a[buf_idx + 2 ] = FLOAT_TYPE(kvalues_iq4nl[v0.z]) * d;
buf_a[buf_idx + 3 ] = FLOAT_TYPE(kvalues_iq4nl[v0.w]) * d;
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[v1.x]) * d;
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[v1.y]) * d;
buf_a[buf_idx + 18] = FLOAT_TYPE(kvalues_iq4nl[v1.z]) * d;
buf_a[buf_idx + 19] = FLOAT_TYPE(kvalues_iq4nl[v1.w]) * d;
#endif
}
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
Expand Down
10 changes: 8 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,17 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);

for (const auto& tname : type_names) {
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq4_nl"))
load_vec_quant = "8";
if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0"))
load_vec_quant = "4";

std::string data_a_key = "DATA_A_" + to_uppercase(tname);
// For unaligned, load one at a time for f32/f16, or two at a time for quants
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
// For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;

// don't generate f32 variants for coopmat2
if (!coopmat2) {
Expand Down
Loading