|
#decl(SHMEM_VEC)
// Vectorized variant of store_shmem: scatters the four components of a
// vec4<f16> into consecutive shared-memory slots starting at idx.
fn store_shmem(val: vec4<f16>, idx: u32) {
    for (var c = 0u; c < 4u; c++) {
        shmem[idx + c] = val[c];
    }
}
#enddecl(SHMEM_VEC)
| 9 | + |
#decl(SHMEM_SCALAR)
// Scalar variant of store_shmem: writes a single f16 into shared memory at idx.
// Selected (instead of SHMEM_VEC) when the shader is generated without
// vectorized loads.
fn store_shmem(val: f16, idx: u32) {
    shmem[idx] = val;
}
#enddecl(SHMEM_SCALAR)
| 15 | + |
#decl(INIT_SRC0_SHMEM_FLOAT)

// Cooperatively loads one (TILE_M x TILE_K) tile of src0 (float weights) from
// global memory into workgroup shared memory. Threads stride across the flat
// tile, each moving {{VEC_SIZE}} elements per iteration.
//   thread_id:    linear index of this thread within the workgroup
//   batch_offset: element offset of the current batch within src0
//   offset_m:     first output row (m) covered by this tile
//   k_outer:      first k index of the current k-tile
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        // Decompose the flat tile index into (row, k) tile coordinates.
        let tile_m = elem_idx / TILE_K;
        let tile_k = elem_idx % TILE_K;
        let global_m = offset_m + tile_m;
        let global_k = k_outer + tile_k;
        // stride_01 is the element stride between consecutive rows of src0.
        let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
        // Out-of-range tile elements are zero-filled so the inner product over
        // the full tile stays correct at matrix edges.
        let src0_val = select( // taking a slight performance hit to avoid oob
            {{SRC0_TYPE}}(0.0),
            src0[src0_idx/{{VEC_SIZE}}],
            global_m < params.m && global_k < params.k);
        // Convert to the shared-memory element type; store_shmem is the
        // scalar or vec4 variant chosen by the SHMEM_* decl.
        store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
    }
}

#enddecl(INIT_SRC0_SHMEM_FLOAT)
| 34 | + |
#decl(INIT_SRC1_SHMEM)

// Cooperatively loads one (TILE_N x TILE_K) tile of src1 (activations) into
// workgroup shared memory, placed immediately after the src0 tile (offset
// TILE_SRC0_SHMEM). Mirrors init_shmem_src0 but indexes by output column n.
//   thread_id:    linear index of this thread within the workgroup
//   batch_offset: element offset of the current batch within src1
//   offset_n:     first output column (n) covered by this tile
//   k_outer:      first k index of the current k-tile
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        // Decompose the flat tile index into (column, k) tile coordinates.
        let tile_n = elem_idx / TILE_K;
        let tile_k = elem_idx % TILE_K;
        let global_n = offset_n + tile_n;
        let global_k = k_outer + tile_k;
        // stride_11 is the element stride between consecutive columns of src1.
        let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
        // Zero-fill out-of-range elements to keep edge tiles correct.
        let src1_val = select(
            {{SRC1_TYPE}}(0.0),
            src1[src1_idx/{{VEC_SIZE}}],
            global_n < params.n && global_k < params.k);
        // src1's tile lives after src0's region in the shared buffer.
        store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
    }
}

#enddecl(INIT_SRC1_SHMEM)
| 53 | + |
#decl(INIT_SRC0_SHMEM_Q4_0)

// Number of quantized weights per q4_0 block.
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
// Number of weights each thread dequantizes per loop iteration (half a block).
const NQ = 16u;
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;

// Dequantizes one k-tile of q4_0 src0 into shared memory. src0 is viewed as
// an array of f16 where each block is [scale, 8 x packed-weight f16]; each
// packed f16 holds 4 nibbles. Within a block, the low nibbles decode to the
// first 16 weights and the high nibbles to the next 16 (hence the +16u below).
//   thread_id:    linear index of this thread within the workgroup
//   batch_offset: block offset of the current batch within src0
//   offset_m:     first output row (m) covered by this tile
//   k_outer:      first k index of the current k-tile (multiple of BLOCK_SIZE)
// NOTE(review): unlike the float path, out-of-range elements are not
// zero-filled here — presumably shmem is zeroed or fully overwritten
// elsewhere; confirm edge-tile behavior.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
        // Which block this thread's 16 weights fall in, and which packed f16
        // within that block it starts at (0 or 4, since NQ covers half a block).
        let blck_idx = i / BLOCK_SIZE;
        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
        // Each packed f16 yields 2 low-nibble weights in the first half of the
        // block's shmem span, so the write base advances by 2 per packed f16.
        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;

        // Map block index to (row, block-of-k) coordinates.
        let tile_m = blck_idx / BLOCKS_K;
        let global_m = offset_m + tile_m;
        let block_k = blck_idx % BLOCKS_K;
        let global_k = k_outer / BLOCK_SIZE + block_k;

        // Bounds check in block units along k.
        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
            // stride_01 is in blocks for quantized src0; scale_idx is the f16
            // index of the block's scale factor.
            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
            let scale_idx = src0_idx * F16_PER_BLOCK;
            let d = src0[scale_idx];

            // Process packed weights two f16 at a time so they can be
            // bitcast to a u32 and unpacked byte-wise.
            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                let q_0 = src0[scale_idx + 1u + block_offset + j];
                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];

                let q_packed = bitcast<u32>(vec2(q_0, q_1));
                for (var k = 0u; k < 4u; k++) {
                    let q_byte = get_byte(q_packed, k);
                    // q4_0 stores unsigned nibbles with an implicit -8 offset,
                    // scaled by the block's f16 scale d.
                    let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
                    let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
                    // Low nibbles fill the first half of the block's 32-slot
                    // span; high nibbles fill the second half.
                    shmem[shmem_idx + j * 2 + k] = q_lo;
                    shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
                }
            }
        }
    }
}

#enddecl(INIT_SRC0_SHMEM_Q4_0)