Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ void main() {
const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0;

const uint8_t hm1 = uint8_t(1 << (2*v_im));
const uint8_t hm2 = uint8_t(hm1 << 4);

FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp

[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
Expand Down Expand Up @@ -71,6 +68,18 @@ void main() {
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;

uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));

uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;

qs0_16_u32_lo4 += qs0_16_lo4_offset16;
qs0_16_u32_hi4 += qs0_16_hi4_offset16;
qs64_80_u32_lo4 += qs64_80_lo4_offset16;
qs64_80_u32_hi4 += qs64_80_hi4_offset16;

uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
Expand Down Expand Up @@ -102,31 +111,26 @@ void main() {
B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];

uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
uint32_t qh1 = qh0 >> 8;
uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
uint32_t qh17 = qh16 >> 8;

const FLOAT_TYPE sx =
fma(FLOAT_TYPE(by10.x), (q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by10.y), (q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by116.x), (q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)),
FLOAT_TYPE(by116.y) * (q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by10.x), q4_0,
fma(FLOAT_TYPE(by10.y), q4_1,
fma(FLOAT_TYPE(by116.x), q4_2,
FLOAT_TYPE(by116.y) * q4_3)));
const FLOAT_TYPE sy =
fma(FLOAT_TYPE(by132.x), (q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by132.y), (q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by148.x), (q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)),
FLOAT_TYPE(by148.y) * (q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by132.x), q4_4,
fma(FLOAT_TYPE(by132.y), q4_5,
fma(FLOAT_TYPE(by148.x), q4_6,
FLOAT_TYPE(by148.y) * q4_7)));
const FLOAT_TYPE sz =
fma(FLOAT_TYPE(by20.x), (q4_8 + (((qh0 & hm2) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by20.y), (q4_9 + (((qh1 & hm2) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by216.x), (q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)),
FLOAT_TYPE(by216.y) * (q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by20.x), q4_8,
fma(FLOAT_TYPE(by20.y), q4_9,
fma(FLOAT_TYPE(by216.x), q4_10,
FLOAT_TYPE(by216.y) * q4_11)));
const FLOAT_TYPE sw =
fma(FLOAT_TYPE(by232.x), (q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by232.y), (q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by248.x), (q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)),
FLOAT_TYPE(by248.y) * (q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by232.x), q4_12,
fma(FLOAT_TYPE(by232.y), q4_13,
fma(FLOAT_TYPE(by248.x), q4_14,
FLOAT_TYPE(by248.y) * q4_15)));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
Expand Down
Loading