| 
 | 1 | +/*  | 
 | 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 3 | + * All rights reserved.  | 
 | 4 | + *  | 
 | 5 | + * This source code is licensed under the BSD-style license found in the  | 
 | 6 | + * LICENSE file in the root directory of this source tree.  | 
 | 7 | + */  | 
 | 8 | + | 
 | 9 | +#version 450 core  | 
 | 10 | + | 
 | 11 | +#define PRECISION ${PRECISION}  | 
 | 12 | + | 
 | 13 | +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}  | 
 | 14 | +#define FLOAT_T ${buffer_scalar_type(DTYPE)}  | 
 | 15 | + | 
 | 16 | +${define_active_storage_type(STORAGE)}  | 
 | 17 | + | 
 | 18 | +${define_required_extensions(DTYPE)}  | 
 | 19 | +${define_required_extensions("int8")}  | 
 | 20 | + | 
 | 21 | + | 
 | 22 | +$if BATCH_MODE:  | 
 | 23 | +  #define BATCH_MODE  | 
 | 24 | + | 
 | 25 | +#define TILE_ROWS ${TILE_ROWS}  | 
 | 26 | +#define FOUR 4  | 
 | 27 | + | 
 | 28 | +// we avoid mat4 and vec4 usage here as they compile to much less efficient  | 
 | 29 | +// SPIR-V  | 
// Accumulator tile for the non-batched (2D) case: TILE_ROWS output rows by
// FOUR output columns (one output texel wide). Plain float arrays are used
// instead of vec4/mat4 — see the note above about SPIR-V efficiency.
struct FloatMatrix_2d {
  float data[TILE_ROWS][FOUR];
};
 | 33 | + | 
// Accumulator tile for BATCH_MODE: same TILE_ROWS x FOUR tile as
// FloatMatrix_2d, replicated across FOUR batch entries in the innermost
// dimension (indexed by batch_idx in the compute loop below).
struct FloatMatrix_3d {
  float data[TILE_ROWS][FOUR][FOUR];
};
 | 37 | + | 
 | 38 | +#ifdef BATCH_MODE  | 
 | 39 | +  #define FloatMatrix FloatMatrix_3d  | 
 | 40 | +#else  | 
 | 41 | +  #define FloatMatrix FloatMatrix_2d  | 
 | 42 | +#endif  | 
 | 43 | + | 
 | 44 | +#include "indexing_utils.h"  | 
 | 45 | + | 
 | 46 | +layout(std430) buffer;  | 
 | 47 | + | 
 | 48 | +${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}  | 
 | 49 | +${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}  | 
 | 50 | +${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)}  | 
 | 51 | +${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)}  | 
 | 52 | + | 
 | 53 | +$if STORAGE == "buffer":  | 
 | 54 | +  ${layout_declare_ubo(4, "ivec4", "out_sizes")}  | 
 | 55 | +  ${layout_declare_ubo(5, "ivec4", "out_strides")}  | 
 | 56 | +  ${layout_declare_ubo(6, "int", "out_numel")}  | 
 | 57 | +  ${layout_declare_ubo(7, "ivec4", "mat1_sizes")}  | 
 | 58 | +  ${layout_declare_ubo(8, "ivec4", "mat1_strides")}  | 
 | 59 | +  ${layout_declare_ubo(9, "ivec4", "qmat2_strides")}  | 
 | 60 | +  ${layout_declare_ubo(10, "ivec4", "scales_strides")}  | 
 | 61 | +$else:  | 
 | 62 | +  ${layout_declare_ubo(4, "ivec3", "out_limits")}  | 
 | 63 | +  ${layout_declare_ubo(5, "ivec4", "mat1_sizes")}  | 
 | 64 | + | 
 | 65 | +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;  | 
 | 66 | + | 
 | 67 | +// This header file must be defined after the layout descriptors have been  | 
 | 68 | +// declared because the functions in the header assume some variables have been  | 
 | 69 | +// declared as layout descriptors.  | 
 | 70 | + | 
 | 71 | +#ifdef USING_BUFFER  | 
 | 72 | + | 
 | 73 | +#ifndef FLOAT_T  | 
 | 74 | +#define FLOAT_T float  | 
 | 75 | +#endif  | 
 | 76 | + | 
 | 77 | +FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {  | 
 | 78 | +  const FLOAT_T scale = t_scales[out_idx.x];  | 
 | 79 | + | 
 | 80 | +  FLOAT_T outval = FLOAT_T(0.0);  | 
 | 81 | + | 
 | 82 | +  // Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0)  | 
 | 83 | +  int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z;  | 
 | 84 | +  // Initial qmat2 tensor idx wil be (0, out_idx.x, 0, 0); note that the qmat2  | 
 | 85 | +  // tensor is transposed  | 
 | 86 | +  int qmat2_offset = out_idx.x * qmat2_strides.y;  | 
 | 87 | + | 
 | 88 | +  // TODO(ssjia): optimize memory access pattern by traversing K in inner loop  | 
 | 89 | +  for (int i = 0; i < K; i++) {  | 
 | 90 | +    const FLOAT_T mat1_val = t_mat1[mat1_offset];  | 
 | 91 | +    const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;  | 
 | 92 | + | 
 | 93 | +    outval += mat1_val * mat2_val;  | 
 | 94 | + | 
 | 95 | +    mat1_offset++;  | 
 | 96 | +    qmat2_offset++;  | 
 | 97 | +  }  | 
 | 98 | + | 
 | 99 | +  return outval;  | 
 | 100 | +}  | 
 | 101 | + | 
 | 102 | +void main() {  | 
 | 103 | +  const int out_bufi = int(gl_GlobalInvocationID.x);  | 
 | 104 | +  if (out_bufi >= out_numel) {  | 
 | 105 | +    return;  | 
 | 106 | +  }  | 
 | 107 | + | 
 | 108 | +  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0);  | 
 | 109 | + | 
 | 110 | +  t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x);  | 
 | 111 | +}  | 
 | 112 | + | 
 | 113 | +#else // USING_TEXTURE  | 
// Texture-storage path: computes a TILE_ROWS x FOUR tile of output values
// (per batch entry when BATCH_MODE is defined) for the 8-bit-weight
// quantized linear op. out_idx_tl is the tile's top-left position in the
// output texture (x = output texel column, y = tile row, z = batch group).
FloatMatrix q_8w_linear_optimized(const ivec3 out_idx_tl) {
  FloatMatrix results;
  // Zero-initialize the accumulator tile.
  for (int i = 0; i < TILE_ROWS; i++) {
    for (int j = 0; j < FOUR; j++) {
#ifdef BATCH_MODE
      for (int k = 0; k < FOUR; k++) {
        results.data[i][j][k] = 0.0f;
      }
#else
      results.data[i][j] = 0.0f;
#endif // BATCH_MODE
    }
  }

  // Staging registers for one column of mat1 texels and one column of
  // (transposed) weight texels per reduction step.
  VEC4_T im_mat1_partial_load[TILE_ROWS];
  VEC4_T im_mat2_partial_load[FOUR];

#ifdef BATCH_MODE
  // In batch mode each invocation covers FOUR consecutive batch entries;
  // stop early at the tail of the batch dimension.
  for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
    if (out_idx_tl.z + batch_idx >= out_limits.z) {
      break;
    }
#endif
    // Walk the reduction dimension K one texel-column at a time.
    for (int k = 0; k < mat1_sizes.x; k++) {
      // Load TILE_ROWS rows of mat1 at column k.
      for (int r = 0; r < TILE_ROWS; r++) {
        ivec3 mat1_pos = ivec3(k, out_idx_tl.y * TILE_ROWS + r, 0);
#ifdef BATCH_MODE
        mat1_pos[2] = out_idx_tl.z + batch_idx;
#endif

        im_mat1_partial_load[r] = texelFetch(t_mat1, mat1_pos, 0);
      }

      // Load the FOUR weight rows feeding this output texel; qmat2 is
      // stored transposed, so output channels index its y dimension.
      for (int r = 0; r < FOUR; ++r) {
        ivec3 qmat2_pos = ivec3(k, FOUR * out_idx_tl.x + r, 0);

        im_mat2_partial_load[r] = texelFetch(t_qmat2, qmat2_pos, 0);
      }

      // Per-output-channel dequantization scales for this texel.
      vec4 scales = texelFetch(t_scales, ivec3(out_idx_tl.x, 0, 0), 0);

      // perform partial dot products and add partial result to results
      for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
        for (int out_col = 0; out_col < FOUR; out_col++) {
          // NOTE: the += statement is split across the #ifdef so the same
          // dot() expression serves both accumulator layouts.
#ifdef BATCH_MODE
          results.data[out_row][out_col][batch_idx] +=
#else
        results.data[out_row][out_col] +=
#endif
              dot(im_mat1_partial_load[out_row],
                  im_mat2_partial_load[out_col] * scales[out_col]);
        }
      }
    }
#ifdef BATCH_MODE
  }
#endif
  return results;
}
 | 173 | + | 
// Texture-storage entry point: each invocation computes a TILE_ROWS x FOUR
// output tile (times FOUR batch entries in BATCH_MODE) and writes it out
// texel by texel. Note the mix of runtime #ifdef and codegen-time $if
// directives — both select on BATCH_MODE and must stay consistent.
void main() {
  const ivec3 out_idx = ivec3(gl_GlobalInvocationID);
  if (any(greaterThanEqual(out_idx, out_limits))) {
    return;
  }

  FloatMatrix results = q_8w_linear_optimized(out_idx);

  // Convert the tile index back to an output texel position; y is scaled by
  // the tile height, and z by the batch-group width in BATCH_MODE.
  ivec3 out_pos = ivec3(
      out_idx.x,
      out_idx.y * TILE_ROWS,
#ifdef BATCH_MODE
      out_idx.z * 4
#else
      out_idx.z
#endif
);

  // Write one texel per tile row (and per batch entry in BATCH_MODE).
  for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++, out_pos[1]++) {
    out_pos.x = out_idx.x;
    $if BATCH_MODE:
      for (int idx_r = 0; idx_r < FOUR; idx_r++, out_pos[0]++) {
        write_texel(t_out, out_pos, VEC4_T(
              results.data[idx_c][idx_r][0],
              results.data[idx_c][idx_r][1],
              results.data[idx_c][idx_r][2],
              results.data[idx_c][idx_r][3]));
      }
    $else:
      write_texel(t_out, out_pos, VEC4_T(
              results.data[idx_c][0],
              results.data[idx_c][1],
              results.data[idx_c][2],
              results.data[idx_c][3]));
  }
}
 | 210 | + | 
 | 211 | +#endif  | 