-
Notifications
You must be signed in to change notification settings - Fork 13.7k
Vulkan: improve mul_mat_vec_iq1_m #16907
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,35 +7,86 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | |
|
|
||
| FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; | ||
|
|
||
| void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { | ||
| // ------------------ calc_superblock (final optimized version) ------------------ | ||
| void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, | ||
| const uint num_blocks_per_row, const uint first_row, const uint num_rows) { | ||
| // Compute starting index in matrix B for this superblock | ||
| const uint y_idx = i * QUANT_K + 32 * ib32; | ||
|
|
||
| uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; | ||
|
|
||
| // Precompute indices for quantization lookup tables | ||
| const uint qh_base = 2 * ib32; | ||
| const uint qs_base = 4 * ib32; | ||
| const uint sc_index = ib32 / 2; | ||
| const uint sc_shift = 6 * (ib32 & 1); | ||
|
|
||
| // Loop over rows in the superblock | ||
| [[unroll]] for (uint n = 0; n < num_rows; ++n) { | ||
| // Load per-block scales and shift for quantization | ||
| const uint16_t[4] scales = data_a[ibi].scales; | ||
| const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; | ||
| const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); | ||
|
|
||
| const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1)); | ||
| const uint sc = data_a[ibi].scales[sc_index] >> sc_shift; | ||
|
|
||
| // Temporary caches for decoding | ||
| FLOAT_TYPE dl_cache[4]; | ||
| uint16_t gvf_cache[4]; | ||
| float delta_cache[4]; | ||
|
|
||
| // Precompute the multiplier and lookup values for 4 sub-blocks | ||
| [[unroll]] for (uint l = 0; l < 4; ++l) { | ||
| const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1)); | ||
| const uint qs = data_a[ibi].qs[4 * ib32 + l]; | ||
| const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; | ||
| const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1); | ||
|
|
||
| const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); | ||
|
|
||
| [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||
| vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); | ||
| vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); | ||
|
|
||
| FLOAT_TYPE sum = FLOAT_TYPE(0.0); | ||
| [[unroll]] for (int k = 0; k < 4; ++k) { | ||
| sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, | ||
| fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); | ||
| } | ||
| temp[j][n] = fma(dl, sum, temp[j][n]); | ||
| dl_cache[l] = FLOAT_TYPE(d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1)); | ||
| const uint qh = data_a[ibi].qh[qh_base + l / 2] >> (4 * (l & 1)); | ||
| const uint qs = data_a[ibi].qs[qs_base + l]; | ||
| gvf_cache[l] = iq1s_grid[qs | ((qh & 7) << 8)]; | ||
| delta_cache[l] = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; | ||
| } | ||
|
|
||
| // Loop over columns of the output | ||
| [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||
| // Compute base index for matrix B | ||
| const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4; | ||
| vec4 b_vals[8]; | ||
|
|
||
| // Load 8 vec4 values from matrix B | ||
| [[unroll]] for (int idx = 0; idx < 8; ++idx) { | ||
| b_vals[idx] = vec4(data_b_v4[base_b_idx + idx]); | ||
| } | ||
|
|
||
| FLOAT_TYPE col_sum = FLOAT_TYPE(0.0); | ||
|
|
||
| // Loop over sub-blocks | ||
| [[unroll]] for (uint l = 0; l < 4; ++l) { | ||
| const uint16_t grid = gvf_cache[l]; | ||
| const float dl = dl_cache[l]; | ||
|
|
||
| // Decode 8 2-bit fbits from gvf_cache | ||
| float f0 = float(bitfieldExtract(grid, 0, 2)); | ||
| float f1 = float(bitfieldExtract(grid, 2, 2)); | ||
| float f2 = float(bitfieldExtract(grid, 4, 2)); | ||
| float f3 = float(bitfieldExtract(grid, 6, 2)); | ||
| float f4 = float(bitfieldExtract(grid, 8, 2)); | ||
| float f5 = float(bitfieldExtract(grid, 10, 2)); | ||
| float f6 = float(bitfieldExtract(grid, 12, 2)); | ||
| float f7 = float(bitfieldExtract(grid, 14, 2)); | ||
|
|
||
| // Pack into vec4 for vectorized FMA | ||
| const vec4 fbits_v0 = vec4(f0, f1, f2, f3); | ||
| const vec4 fbits_v1 = vec4(f4, f5, f6, f7); | ||
| const vec4 delta_v = vec4(delta_cache[l]); | ||
|
|
||
| // Vectorized fused multiply-add | ||
| vec4 sum_v = fma(b_vals[2*l + 0], fbits_v0 + delta_v, vec4(0.0)); | ||
| sum_v = fma(b_vals[2*l + 1], fbits_v1 + delta_v, sum_v); | ||
|
|
||
| // Horizontal add to get scalar sum | ||
| FLOAT_TYPE sum = sum_v.x + sum_v.y + sum_v.z + sum_v.w; | ||
|
|
||
| // Accumulate to column sum | ||
| col_sum = fma(dl, sum, col_sum); | ||
| } | ||
| // Write result to temporary buffer | ||
| temp[j][n] += col_sum; | ||
| } | ||
| ibi += num_blocks_per_row; | ||
| } | ||
|
|
@@ -44,39 +95,39 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, | |
| void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | ||
| uint a_offset, b_offset, d_offset; | ||
| get_offsets(a_offset, b_offset, d_offset); | ||
|
|
||
| const uint num_blocks_per_row = p.ncols / QUANT_K; | ||
|
|
||
| // 8 threads are used to process each block | ||
| const uint blocks_per_wg = gl_WorkGroupSize.x/8; | ||
| const uint blocks_per_wg = gl_WorkGroupSize.x / 8; | ||
| const uint tid = gl_LocalInvocationID.x; | ||
| const uint itid = tid % 8; // 0...7 | ||
| const uint itid = tid % 8; | ||
| const uint ix = tid / 8; | ||
|
|
||
| [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||
| [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { | ||
| // Initialize temporary storage | ||
| [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) | ||
| [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) | ||
| temp[j][i] = FLOAT_TYPE(0); | ||
| } | ||
| } | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All of the above changes in this function are just code style, please revert them. It's okay to improve readability and style of code you're touching anyways, but that doesn't apply here. I also prefer to keep curly brackets after loops or ifs. |
||
|
|
||
| // Loop over blocks assigned to this thread | ||
| [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) | ||
| calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); | ||
|
|
||
| // Reduce results from temporary buffer to output | ||
| reduce_result(temp, d_offset, first_row, num_rows, tid); | ||
| } | ||
|
|
||
| void main() { | ||
| // Compute first row for this workgroup | ||
| const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); | ||
|
|
||
| // Initialize shared memory for quantization lookups | ||
| init_iq_shmem(gl_WorkGroupSize); | ||
|
|
||
| // do NUM_ROWS at a time, unless there aren't enough remaining rows | ||
| if (first_row + NUM_ROWS <= p.stride_d) { | ||
| compute_outputs(first_row, NUM_ROWS); | ||
| } else { | ||
| if (first_row >= p.stride_d) { | ||
| return; | ||
| } | ||
| compute_outputs(first_row, p.stride_d - first_row); | ||
| } | ||
| // Early exit if out-of-bounds | ||
| if (first_row >= p.stride_d) | ||
| return; | ||
|
|
||
| // Number of rows to process for this workgroup | ||
| const uint rows_to_process = min(NUM_ROWS, p.stride_d - first_row); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm pretty surprised if it helped to make the changes in this function - this will prevent the compiler from unrolling loops.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see a difference from adding this, so I would prefer to keep it as it was. @lovedheart Can you benchmark if the changes in the main function make a difference for you? |
||
|
|
||
| // Compute outputs for assigned rows | ||
| compute_outputs(first_row, rows_to_process); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment isn't necessary.