/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

${define_required_extensions("uint8")}
${define_required_extensions("int8")}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}

layout(push_constant) uniform restrict Block {
  ivec4 qmat2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Extract the value stored in the upper (leftmost) 4 bits of a packed byte.
uint8_t get_first(const uint8_t packed) {
  return uint8_t((packed & 0xF0) >> 4);
}

// Extract the value stored in the lower (rightmost) 4 bits of a packed byte.
uint8_t get_second(const uint8_t packed) {
  return uint8_t(packed & 0x0F);
}

// Pack two 4-bit values into a single byte; `first` occupies the upper 4 bits
// and `second` the lower 4 bits.
uint8_t combine(const uint8_t first, const uint8_t second) {
  return uint8_t(first << 4 | second);
}

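/*
 * A minimal round-trip sketch of the helpers above (illustrative only; not
 * used by the shader itself): for a packed byte such as 0x5A,
 *
 *   get_first(uint8_t(0x5A))  == uint8_t(0x5)
 *   get_second(uint8_t(0x5A)) == uint8_t(0xA)
 *   combine(uint8_t(0x5), uint8_t(0xA)) == uint8_t(0x5A)
 *
 * so splitting and re-packing a pair of 4-bit values is lossless.
 */
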
/*
 * This shader packs the weight tensor into the output texture (or buffer,
 * depending on the STORAGE setting).
 *
 * The original tensor has a (W, H) shape of (K / 2, N), and each scalar
 * element is a uint8_t which contains 2 packed 4-bit uint values.
 *
 * The transform performed by this shader first transposes the tensor, so that
 * the shape of the packed tensor becomes (N / 2, K). Then, the 4-bit integers
 * are re-packed in groups of 8. For each group of 4 uint8_t values, the
 * "left" 4 bits of each value contain 4-bit values 0 through 3 of the group,
 * and the "right" 4 bits of each value contain 4-bit values 4 through 7.
 *
 * As a concrete example, consider the following weight tensor. The | marks
 * the packing boundary within each byte, so 1| 2 represents a single uint8_t
 * value with 1 in the leftmost 4 bits and 2 in the rightmost 4 bits.
 *
 *  1| 2,  3| 4,  5| 6,  7| 8,
 *  9|10, 11|12, 13|14, 15|16,
 * 17|18, 19|20, 21|22, 23|24,
 * 25|26, 27|28, 29|30, 31|32,
 * 33|34, 35|36, 37|38, 39|40,
 * 41|42, 43|44, 45|46, 47|48,
 * 49|50, 51|52, 53|54, 55|56,
 * 57|58, 59|60, 61|62, 63|64,
 *
 * After packing, the packed tensor would contain
 *
 * 1|33,  9|41, 17|49, 25|57,
 * 2|34, 10|42, 18|50, 26|58,
 * 3|35, 11|43, 19|51, 27|59,
 * 4|36, 12|44, 20|52, 28|60,
 * 5|37, 13|45, 21|53, 29|61,
 * 6|38, 14|46, 22|54, 30|62,
 * 7|39, 15|47, 23|55, 31|63,
 * 8|40, 16|48, 24|56, 32|64,
 *
 * The purpose of interleaving is to make it easier to extract the unpacked
 * values in order using the u8vec4 vectorized type. With the packing in
 * place, the 4-bit values can be extracted via
 *
 *   u8vec4 packed;
 *   u8vec4 vals_0123 = (packed & 0xF0) >> 4;
 *   u8vec4 vals_4567 = (packed & 0x0F);
 */
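/*
 * Tying this to the example above (an illustrative sketch, not shader logic):
 * the first output texel holds the bytes (1|33, 9|41, 17|49, 25|57), so
 *
 *   vals_0123 -> (1, 9, 17, 25)
 *   vals_4567 -> (33, 41, 49, 57)
 *
 * which together are the first 8 values of column 0 of the original tensor,
 * i.e. one contiguous run of a transposed row.
 */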
void main() {
  // Each thread writes 2 output texels along the height axis
  ivec2 packed_pos = ivec2(
      gl_GlobalInvocationID.x,
      gl_GlobalInvocationID.y << 1);

  // The output tensor is packed along the width dim; each u8vec4 texel covers
  // 4 uint8 columns, hence packed_pos.x is scaled by 4 for the bounds check.
  if ((packed_pos.x << 2) >= qmat2_sizes.x || packed_pos.y >= qmat2_sizes.y) {
    return;
  }

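  // Destination coordinates in unpacked units: each u8vec4 texel spans
  // 4 bytes = 8 4-bit elements, so out_col is an index over 4-bit columns.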
  int out_col = packed_pos.x << 3;
  int out_row = packed_pos.y;

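  // Transpose mapping: the output row reads down an input column, and the
  // output column selects the input row. in_int8_col halves in_col because
  // each input byte stores two 4-bit columns.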
  int in_col = out_row;
  int in_int8_col = in_col >> 1;
  int in_row = out_col;

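  // Input extents recovered from the output sizes: the input has N rows
  // (twice the output's byte width) and K / 2 byte columns (half the output
  // height). in_numcols is the input width counted in 4-bit elements.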
  int in_numrows = qmat2_sizes.x << 1;
  int in_numcols = qmat2_sizes.y;
  int in_num_int8_cols = qmat2_sizes.y >> 1;

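  // Gather the 8 4-bit values for this output group, splitting each input
  // byte into its two nibbles. Rows past the end of the input are padded
  // with a sentinel value (254).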
  uint8_t in_vals[8][2];
  for (int r = 0; r < 8; ++r) {
    if (in_row + r < in_numrows) {
      uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
      in_vals[r][0] = get_first(in_val_packed);
      in_vals[r][1] = get_second(in_val_packed);
    } else {
      in_vals[r][0] = uint8_t(254);
      in_vals[r][1] = uint8_t(254);
    }
  }

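  // Interleave as described above: values 0-3 of the group go into the upper
  // nibbles and values 4-7 into the lower nibbles of each output byte.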
  u8vec4 out_tex_1 = u8vec4(
      combine(in_vals[0][0], in_vals[4][0]),
      combine(in_vals[1][0], in_vals[5][0]),
      combine(in_vals[2][0], in_vals[6][0]),
      combine(in_vals[3][0], in_vals[7][0]));

  u8vec4 out_tex_2 = u8vec4(
      combine(in_vals[0][1], in_vals[4][1]),
      combine(in_vals[1][1], in_vals[5][1]),
      combine(in_vals[2][1], in_vals[6][1]),
      combine(in_vals[3][1], in_vals[7][1]));

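  // Write the two texels. For buffer storage, the row stride is the byte
  // width divided by 4 since the buffer is addressed in u8vec4 units; for
  // texture storage, write two texels stacked along the height axis.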
  $if STORAGE == "buffer":
    int stride = qmat2_sizes.x >> 2;
    t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
    t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
  $else:
    imageStore(t_qmat2, ivec3(packed_pos.xy, 0), out_tex_1);
    imageStore(t_qmat2, ivec3(packed_pos.x, packed_pos.y + 1, 0), out_tex_2);
}