 
 #define PRECISION ${PRECISION}
 
-${define_required_extensions("uint8")}
-${define_required_extensions("int8")}
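+// NO_INT8_BUFFERS is presumably defined for devices that do not support
+// 8-bit storage buffers; the uint8 extension is then skipped and the input
+// buffer is read as packed 32-bit uints instead (see extract_comp below).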
+$if not NO_INT8_BUFFERS:
+  ${define_required_extensions("uint8")}
+$if STORAGE == "buffer":
+  ${define_required_extensions("int8")}
 
 layout(std430) buffer;
 
 ${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}
+$if NO_INT8_BUFFERS:
+  ${layout_declare_tensor(B, "r", "nchw_4x2", "uint", "buffer")}
+$else:
+  ${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}
 
 layout(push_constant) uniform restrict Block {
   ivec4 qmat2_sizes;
 };
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-uint8_t get_first(const uint8_t packed) {
-  return uint8_t((packed & 0xF0) >> 4);
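+// BUF_T is the scalar type of the input buffer (uint when int8 buffers are
+// unavailable) and UVEC4_T the 4-component type used to assemble output
+// texels.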
+$if NO_INT8_BUFFERS:
+  #define BUF_T uint
+$else:
+  #define BUF_T uint8_t
+
+$if STORAGE == "buffer":
+  #define UVEC4_T u8vec4
+$else:
+  #define UVEC4_T uvec4
+
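+// get_first/get_second return the high and low nibble of a packed byte;
+// combine packs two 4-bit values back into a single byte.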
+uint get_first(const BUF_T packed) {
+  return (packed & 0xF0) >> 4;
 }
 
-uint8_t get_second(const uint8_t packed) {
-  return uint8_t(packed & 0x0F);
+uint get_second(const BUF_T packed) {
+  return packed & 0x0F;
 }
 
-uint8_t combine(const uint8_t first, const uint8_t second) {
-  return uint8_t(first << 4 | second);
+uint combine(const uint first, const uint second) {
+  return (first << 4 | second);
 }
 
-/*
- * This shader packs the weight tensor into a texture.
- *
- * The original tensor has a (W, H) shape of (K / 2, N) and each scalar element
- * is a uint8_t, which contains 2 packed 4-bit uint values.
- *
- * The transform performed by this shader is to first transpose the tensor, so
- * the shape of the packed tensor becomes (N / 2, K). Then, the 4-bit integers
- * are re-packed in groups of 8. Within each group of 4 uint8_t values, the
- * "left" 4 bits of each value contain the 0, 1, 2, 3 4-bit values, and the
- * "right" 4 bits of each value contain the 4, 5, 6, 7 4-bit values.
- *
- * As a concrete example, consider the following weight tensor. The | demarks
- * the packing boundary, so 1|2 represents a single uint8_t value with 1 in the
- * leftmost 4 bits and 2 in the rightmost 4 bits.
- *
- *  1| 2,  3| 4,  5| 6,  7| 8,
- *  9|10, 11|12, 13|14, 15|16,
- * 17|18, 19|20, 21|22, 23|24,
- * 25|26, 27|28, 29|30, 31|32,
- * 33|34, 35|36, 37|38, 39|40,
- * 41|42, 43|44, 45|46, 47|48,
- * 49|50, 51|52, 53|54, 55|56,
- * 57|58, 59|60, 61|62, 63|64,
- *
- * After packing, the packed tensor would contain
- *
- * 1|33,  9|41, 17|49, 25|57,
- * 2|34, 10|42, 18|50, 26|58,
- * 3|35, 11|43, 19|51, 27|59,
- * 4|36, 12|44, 20|52, 28|60,
- * 5|37, 13|45, 21|53, 29|61,
- * 6|38, 14|46, 22|54, 30|62,
- * 7|39, 15|47, 23|55, 31|63,
- * 8|40, 16|48, 24|56, 32|64,
- *
- * The purpose of interleaving is to make it easier to extract the unpacked
- * values in order using the u8vec4 vectorized type. With the packing in place,
- * the 4-bit values can be extracted via
- *
- * u8vec4 packed;
- * u8vec4 vals_0123 = (packed & 0xF0) >> 4;
- * u8vec4 vals_4567 = packed & 0x0F;
- */
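+// extract_comp returns the idx-th byte of a uint that holds four packed
+// bytes, assuming little-endian byte order within the word.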
+$if NO_INT8_BUFFERS:
+  uint extract_comp(const uint packed4, const uint idx) {
+    return (packed4 >> (idx * 8)) & 0xFF;
+  }
+
 void main() {
   // Each thread writes 2 output texels along the height axis
   ivec2 packed_pos = ivec2(
@@ -102,25 +78,32 @@ void main() {
   int in_numcols = qmat2_sizes.y;
   int in_num_int8_cols = qmat2_sizes.y >> 1;
 
-  uint8_t in_vals[8][2];
+  uint in_vals[8][2];
   for (int r = 0; r < 8; ++r) {
     if (in_row + r < in_numrows) {
-      uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
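+      // scalar_idx is the flat index of the byte to read; with uint-packed
+      // input it is split into a word index (>> 2) and a byte index (% 4).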
+      uint scalar_idx = (in_row + r) * in_num_int8_cols + in_int8_col;
+      $if NO_INT8_BUFFERS:
+        BUF_T in_val_packed_texel = nchw_4x2[scalar_idx >> 2];
+        const uint packed_idx = scalar_idx % 4;
+        uint in_val_packed = extract_comp(in_val_packed_texel, packed_idx);
+      $else:
+        BUF_T in_val_packed = nchw_4x2[scalar_idx];
+
       in_vals[r][0] = get_first(in_val_packed);
       in_vals[r][1] = get_second(in_val_packed);
     } else {
-      in_vals[r][0] = uint8_t(0);
-      in_vals[r][1] = uint8_t(0);
+      in_vals[r][0] = uint(0);
+      in_vals[r][1] = uint(0);
     }
   }
 
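+  // Each output texel packs row r of this 8-row block with row r + 4: high
+  // nibbles hold rows 0-3, low nibbles rows 4-7, matching the interleaving
+  // described in the (removed) header comment.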
-  u8vec4 out_tex_1 = u8vec4(
+  UVEC4_T out_tex_1 = UVEC4_T(
     combine(in_vals[0][0], in_vals[4][0]),
     combine(in_vals[1][0], in_vals[5][0]),
     combine(in_vals[2][0], in_vals[6][0]),
     combine(in_vals[3][0], in_vals[7][0]));
 
-  u8vec4 out_tex_2 = u8vec4(
+  UVEC4_T out_tex_2 = UVEC4_T(
     combine(in_vals[0][1], in_vals[4][1]),
     combine(in_vals[1][1], in_vals[5][1]),
     combine(in_vals[2][1], in_vals[6][1]),
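For reference, a minimal CPU sketch (Python) of the repacking described in the
removed header comment, assuming a row-major (N, K/2) array of bytes with the
high nibble stored first; the function and variable names are illustrative,
not taken from the codebase:

def repack_int4_weights(src, N, K):
    # Unpack each (N, K/2) byte row into K 4-bit values.
    nib = [[0] * K for _ in range(N)]
    for n in range(N):
        for b in range(K // 2):
            nib[n][2 * b] = (src[n][b] & 0xF0) >> 4
            nib[n][2 * b + 1] = src[n][b] & 0x0F

    # Transpose to (K, N/2) bytes, interleaving each 8-row group: row r of
    # the group goes to the high nibble and row r + 4 to the low nibble, so
    # (packed & 0xF0) >> 4 and packed & 0x0F recover rows 0-3 and 4-7.
    dst = [[0] * (N // 2) for _ in range(K)]
    for k in range(K):
        for j in range(N // 2):
            hi = (j // 4) * 8 + (j % 4)
            dst[k][j] = (nib[hi][k] << 4) | nib[hi + 4][k]
    return dst

# Example: 8 rows of one byte each, i.e. nibbles 0..15 in row-major order.
src = [[0x01], [0x23], [0x45], [0x67], [0x89], [0xAB], [0xCD], [0xEF]]
assert repack_int4_weights(src, N=8, K=2) == [
    [0x08, 0x2A, 0x4C, 0x6E],  # high nibbles: rows 0-3; low nibbles: rows 4-7
    [0x19, 0x3B, 0x5D, 0x7F],
]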