${define_active_storage_type(STORAGE)}

${define_required_extensions([DTYPE, "uint8", "uint16"])}
#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

// Bindings (auto-numbered by the codegen via B):
//  - ret:     output tensor (written), same dtype/storage as the input
//  - x:       floating point input activations
//  - weights: 4-bit quantized weights, two nibbles packed per uint8, always a buffer
//  - qparams: per-group scale and zero-point pairs for the weights
${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")}
${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "ret_limits")}
${layout_declare_ubo(B, "ivec4", "x_sizes")}
${layout_declare_ubo(B, "ivec4", "weights_strides")}
// NOTE(review): qparams_strides is not referenced in main() below — qparams is
// read via load_texel. Presumably kept for a buffer-storage codegen variant;
// confirm before removing.
${layout_declare_ubo(B, "ivec4", "qparams_strides")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Quantization group size along K; supplied as a specialization constant.
layout(constant_id = 3) const int group_size = 1;

/*
 * This shader computes a linear operator between a floating point input matrix
 * x and a weights matrix that is quantized to 4 bits.
 *
 * The (W, H, C) shape of each tensor is:
 * - x: (K, M)
 * - weights: (K / 2, N)
 *   - The weights tensor has a data type of `uint8`. Each element in the tensor
 *     contains 2 4-bit values packed into a uint8.
 * - qparams: (2, N, number_of_groups)
 *   - This tensor contains the scales and zeros quantization parameters for the
 *     weights tensor. The weight tensor is quantized group-wise, which means
 *     that every `group_size` elements along the K dimension of the weights
 *     tensor has independent quantization parameters. Along the width dim, the
 *     first value contains the scale for the group and the second value
 *     contains the zero point for the group.
 *
 * Note that this shader assumes that all tensors are width packed.
 */
void main() {
  // Each invocation produces one width-packed output texel, i.e. the four
  // adjacent outputs (n, m), (n + 1, m), (n + 2, m), (n + 3, m). This means
  // multiplying the m-th row of x with the n-th..(n+3)-th rows of the weights
  // tensor.
  const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID);
  if (any(greaterThanEqual(ret_pos, ret_limits))) {
    return;
  }

  // Since ret is width packed, each texel along x covers 4 output columns.
  const uint16_t n = uint16_t(ret_pos.x * 4);

  // K is guaranteed to be a multiple of group size, so no remainder handling
  // is needed for the final block.
  const uint16_t num_blocks = uint16_t(x_sizes.x / group_size);

  uint16_t k_texel_i = uint16_t(0);
  vec4 sums = vec4(0.0);
  for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) {
    vec4 scales;
    vec4 zeros;

    // Fetch this group's quantization parameters for each of the 4 output
    // columns. qparams layout: width index 0 holds (scale, zero) in .xy.
    [[unroll]] for (int comp = 0; comp < 4; comp++) {
      const vec4 scale_and_zero = load_texel(
          qparams, u16vec3(0, n + comp, block_idx));
      scales[comp] = scale_and_zero.x;
      zeros[comp] = scale_and_zero.y;
    }

    // Walk this group's K range one input texel (4 unpacked weights) at a time.
    for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) {
      const VEC4_T x_texel = load_texel(
          x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z));

      [[unroll]] for (int comp = 0; comp < 4; comp++) {
        const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2);
        // Need to read 4 unpacked values, which corresponds to 2 packed values
        const uint8_t weights_val_1 = weights[weights_bufi];
        const uint8_t weights_val_2 = weights[weights_bufi + 1];

        // Unpack high nibble first, then low nibble, for each packed byte.
        const u8vec4 weights_texel = u8vec4(
          (weights_val_1 & 0xF0) >> 4,
          weights_val_1 & 0x0F,
          (weights_val_2 & 0xF0) >> 4,
          weights_val_2 & 0x0F);

        // Note that the unpacked 4-bit values are unsigned, therefore they must
        // first be "centered" around 0 by subtracting 8 before applying the
        // scale and zero point.
        sums[comp] += dot(
            x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]);
      }
    }
  }
  write_texel(ret, ret_pos, sums);
}