/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)}
#define T int

$if OUTPUT_STORAGE == "buffer":
  #define OUTPUT_BUFFER
$if PACKED_INT8_INPUT_STORAGE == "buffer":
  #define PACKED_INT8_INPUT_BUFFER
$if WEIGHT_STORAGE == "buffer":
  #define WEIGHT_BUFFER

#define TILE_M4 ${TILE_M4}
#define TILE_K4 ${TILE_K4}
#define TILE_N4 ${TILE_N4}

#define TILE_M ${TILE_M4 * 4}
#define TILE_K ${TILE_K4 * 4}
#define TILE_N ${TILE_N4 * 4}

#define M_TILES_PER_WG 8
#define N_TILES_PER_WG 8
#define K_TILES_PER_WG 1
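
// Work decomposition (assuming the local workgroup size is bound as
// M_TILES_PER_WG x N_TILES_PER_WG x K_TILES_PER_WG via the specialization
// constants declared below): each workgroup covers an output region of
// (M_TILES_PER_WG * TILE_M) rows x (N_TILES_PER_WG * TILE_N) columns, and the
// K_TILES_PER_WG threads along z cooperate on the reduction (K) dimension.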

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INT8_INPUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}

${layout_declare_spec_const(C, "int", "apply_bias", "0")}

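// As used in main() below: output_sizes.x = N (output features),
// output_sizes.y = M (rows), and input_sizes.x = K (input features).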
${layout_declare_ubo(B, "ivec4", "output_sizes")}
${layout_declare_ubo(B, "ivec4", "input_sizes")}

layout(push_constant) uniform restrict Block {
  float input_scale;
  int input_zp;
};
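
// A sketch of the dequantization identity applied when converting the int32
// accumulator back to floating point (the exact form lives in the .glslh
// helpers included below):
//
//   output[m][n] = input_scale * weight_scales[n]
//                    * (accum[m][n] - input_zp * weight_sums[n])
//                    (+ bias[n] when apply_bias != 0)
//
// where accum[m][n] = sum_k q_input[m][k] * q_weight[k][n] over int8 values
// and weight_sums[n] = sum_k q_weight[k][n] is precomputed on the host.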

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#include "linear_int8_input_tile_load.glslh"
#include "linear_int8_weight_tile_load.glslh"
#include "linear_fp_output_tile_int8_int8_compute.glslh"
#include "linear_fp_output_tile_store.glslh"
#include "linear_fp_weight_scales_load.glslh"
#include "linear_int_weight_sums_load.glslh"
#include "linear_fp_bias_load.glslh"

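// Per-workgroup staging buffer for the K-dimension reduction: each thread
// deposits its int32 accumulator tile here, and the K_TILES_PER_WG partial
// sums along z are then tree-reduced in main().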
shared Int32Accum partial_sums[M_TILES_PER_WG][N_TILES_PER_WG][K_TILES_PER_WG];

// Element-wise accumulate `second` into `first`.
void add_into_first(inout Int32Accum first, const Int32Accum second) {
  [[unroll]] for (int m = 0; m < TILE_M; ++m) {
    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
      first.data[m][n4] += second.data[m][n4];
    }
  }
}

void main() {
  const int m_tile_lid = int(gl_LocalInvocationID.x);
  const int n_tile_lid = int(gl_LocalInvocationID.y);
  const int k_tile_lid = int(gl_LocalInvocationID.z);

  // Each thread computes a TILE_N wide x TILE_M high tile of output values
  const int out_tile_x = int(gl_GlobalInvocationID.x);
  const int out_tile_y = int(gl_GlobalInvocationID.y);

  const int n = out_tile_x * TILE_N;
  const int m = out_tile_y * TILE_M;

  const int n4 = div_4(n);
  const int m4 = div_4(m);

  if (n >= output_sizes.x || m >= output_sizes.y) {
    return;
  }

  const int M = output_sizes.y;
  const int K4 = div_up_4(input_sizes.x);
  const int N4 = div_up_4(output_sizes.x);

  Int32Accum out_accum;
  initialize(out_accum);
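  // out_accum.data holds TILE_M rows of TILE_N4 ivec4 texels (the layout
  // traversed by add_into_first above); initialize() zeroes it before the
  // K loop.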

  Int8InputTile int8_in_tile;
  Int8WeightTile int8_weight_tile;

  // Number of 4-element K blocks consumed by the workgroup per loop
  // iteration.
  const int k4_per_iter = TILE_K4 * K_TILES_PER_WG;

  // No bounds checks are needed since the packed input and weight tensors are
  // structured in units of 4x4 blocks. Each thread along z starts at its own
  // TILE_K4-aligned block so the tiles loaded by different threads do not
  // overlap (equivalent to the plain k_tile_lid offset when TILE_K4 is 1).
  for (int k4 = k_tile_lid * TILE_K4; k4 < K4; k4 += k4_per_iter) {
    load_int8_input_tile(int8_in_tile, k4, m4, K4);
    load_int8_weight_tile(int8_weight_tile, n4, k4, N4);

    int_accumulate_with_int8_weight(out_accum, int8_in_tile, int8_weight_tile);
  }

  partial_sums[m_tile_lid][n_tile_lid][k_tile_lid] = out_accum;

  memoryBarrierShared();
  barrier();

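  // Worked example of the reduction below if K_TILES_PER_WG were 4:
  //   i = 2: threads z=0,1 accumulate the partial sums from z=2,3
  //   i = 1: thread  z=0   accumulates the partial sum  from z=1
  // With the current K_TILES_PER_WG of 1 the loop body never executes, and
  // partial_sums[m_tile_lid][n_tile_lid][0] already holds the full sum.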
  // Tree reduction over the K_TILES_PER_WG partial sums along the z axis to
  // compute the overall result.
  for (int i = K_TILES_PER_WG / 2; i > 0; i /= 2) {
    if (k_tile_lid < i) {
      add_into_first(
          partial_sums[m_tile_lid][n_tile_lid][k_tile_lid],
          partial_sums[m_tile_lid][n_tile_lid][k_tile_lid + i]);
    }
    memoryBarrierShared();
    barrier();
  }

  // Only the first thread along z writes out the reduced result.
  if (k_tile_lid > 0) {
    return;
  }

  out_accum = partial_sums[m_tile_lid][n_tile_lid][0];

  FPPerOutChannelParams weight_scales_tile;
  load_weight_scales_tile(weight_scales_tile, n4);

  IntPerOutChannelParams weight_sums_tile;
  load_weight_sums_tile(weight_sums_tile, n4);

  FPOutTile out_tile;
  initialize(out_tile);

  if (apply_bias > 0) {
    FPPerOutChannelParams bias_tile;
    load_bias_tile(bias_tile, n4);

    accumulate_out_tile_with_int_accum(
        out_tile,
        out_accum,
        input_scale,
        input_zp,
        weight_sums_tile,
        weight_scales_tile,
        bias_tile);
  } else {
    accumulate_out_tile_with_int_accum(
        out_tile,
        out_accum,
        input_scale,
        input_zp,
        weight_sums_tile,
        weight_scales_tile);
  }

  // When a full TILE_M rows remain, the tile cannot run past the end of the
  // output in M, so the store can skip per-row bounds checks.
  if (M - m >= TILE_M) {
    write_output_tile_no_checks(out_tile, n4, m, N4, M);
  } else {
    write_output_tile_with_checks(out_tile, n4, m, N4, M);
  }
}