@@ -10,7 +10,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
1010layout (constant_id = 1) const uint NUM_ROWS = 1;
1111
1212shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
13- shared block_q6_K_packed16_flat blkcache[BLOCK_SIZE/16];
13+ shared block_q6_K_packed16 blkcache[BLOCK_SIZE/16];
1414
1515void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
1616 uint a_offset, b_offset, d_offset;
@@ -22,7 +22,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
2222 const uint it_size = gl_WorkGroupSize.x/16;
2323 const uint tid = gl_LocalInvocationID.x;
2424 const uint itid = tid%16; // 0...15
25- const uint ix = tid/16;
25+ const uint ix = tid/16;
2626
2727 const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
2828 const uint v_in = itid - 8*v_im; // 0...15 or 0...7
@@ -58,24 +58,24 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
5858 // cache full superblock into shared memory with coalesced reads
5959 // we assume 64 threads here!
6060 [[unroll]] for (int l = 0; (l < 4) && (i0 + l < num_blocks_per_row); ++l) {
61- blkcache[l].blkd [tid] = data_a_packed16_flat [ib0 + i0 + l].blkd [tid];
62- // we read beyond the struct size but it looks like vulkan doesn't care? o_O
61+ blkcache[l].ql [tid] = data_a_packed16 [ib0 + i0 + l].ql [tid];
62+ // hacky method of reading beyond ql and the block struct size but it looks like vulkan doesn't care? o_O
6363 // this assumes that the struct is packed in continous 16 bit blocks to work
64- blkcache[l].blkd [64 + tid] = data_a_packed16_flat [ib0 + i0 + l].blkd [64 + tid];
64+ blkcache[l].ql [64 + tid] = data_a_packed16 [ib0 + i0 + l].ql [64 + tid];
6565 }
6666 barrier();
6767 if (i >= num_blocks_per_row)
6868 continue;
6969
70- uint32_t ql0_u32 = uint32_t(blkcache[ix].blkd [ql_offset / 2]) | (uint32_t(blkcache[ix].blkd [ql_offset / 2 + 1]) << 16);
71- uint32_t ql32_u32 = uint32_t(blkcache[ix].blkd [ql_offset / 2 + 16]) | (uint32_t(blkcache[ix].blkd [ql_offset / 2 + 17]) << 16);
70+ uint32_t ql0_u32 = uint32_t(blkcache[ix].ql [ql_offset / 2]) | (uint32_t(blkcache[ix].ql [ql_offset / 2 + 1]) << 16);
71+ uint32_t ql32_u32 = uint32_t(blkcache[ix].ql [ql_offset / 2 + 16]) | (uint32_t(blkcache[ix].ql [ql_offset / 2 + 17]) << 16);
7272
7373 uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
7474 uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
7575 uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
7676 uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
7777
78- uint32_t qh_u32 = uint32_t(blkcache[ix].blkd[64 + qh_offset / 2]) | (uint32_t(blkcache[ix].blkd[64 + qh_offset / 2 + 1]) << 16);
78+ uint32_t qh_u32 = uint32_t(blkcache[ix].qh[ qh_offset / 2]) | (uint32_t(blkcache[ix].qh[ qh_offset / 2 + 1]) << 16);
7979 uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
8080 uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
8181 uint32_t qh4_u32 = (qh_u32 & 0x30303030);
0 commit comments