Commit c3d5ff6

[ET-VK][DO NOT LAND] Experimental smem shader for int8 matmul

1 parent 447157a commit c3d5ff6

3 files changed: +203 -3 lines changed

Lines changed: 174 additions & 0 deletions (new file)
@@ -0,0 +1,174 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)}
#define T int

$if OUTPUT_STORAGE == "buffer":
  #define OUTPUT_BUFFER
$if PACKED_INT8_INPUT_STORAGE == "buffer":
  #define PACKED_INT8_INPUT_BUFFER
$if WEIGHT_STORAGE == "buffer":
  #define WEIGHT_BUFFER

#define TILE_M4 ${TILE_M4}
#define TILE_K4 ${TILE_K4}
#define TILE_N4 ${TILE_N4}

#define TILE_M ${TILE_M4 * 4}
#define TILE_K ${TILE_K4 * 4}
#define TILE_N ${TILE_N4 * 4}

#define M_TILES_PER_WG 8
#define N_TILES_PER_WG 8
#define K_TILES_PER_WG 1

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INT8_INPUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}

${layout_declare_spec_const(C, "int", "apply_bias", "0")}

${layout_declare_ubo(B, "ivec4", "output_sizes")}
${layout_declare_ubo(B, "ivec4", "input_sizes")}

layout(push_constant) uniform restrict Block {
  float input_scale;
  int input_zp;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#include "linear_int8_input_tile_load.glslh"
#include "linear_int8_weight_tile_load.glslh"
#include "linear_fp_output_tile_int8_int8_compute.glslh"
#include "linear_fp_output_tile_store.glslh"
#include "linear_fp_weight_scales_load.glslh"
#include "linear_int_weight_sums_load.glslh"
#include "linear_fp_bias_load.glslh"

// Per-thread int32 partial accumulators, staged in shared memory so that
// threads splitting the K dimension can be reduced within the workgroup.
shared Int32Accum partial_sums[M_TILES_PER_WG][N_TILES_PER_WG][K_TILES_PER_WG];

void add_into_first(inout Int32Accum first, const Int32Accum second) {
  [[unroll]] for (int m = 0; m < TILE_M; ++m) {
    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
      first.data[m][n4] += second.data[m][n4];
    }
  }
}

void main() {
  const int m_tile_lid = int(gl_LocalInvocationID.x);
  const int n_tile_lid = int(gl_LocalInvocationID.y);
  const int k_tile_lid = int(gl_LocalInvocationID.z);

  // Each thread writes out a TILE_N wide x TILE_M high tile of output values
  const int out_tile_x = int(gl_GlobalInvocationID.x);
  const int out_tile_y = int(gl_GlobalInvocationID.y);

  const int n = out_tile_x * TILE_N;
  const int m = out_tile_y * TILE_M;

  const int n4 = div_4(n);
  const int m4 = div_4(m);

  // NOTE: returning here skips the barriers below, which is only defined
  // behavior if all invocations in the workgroup take the same path.
  if (n >= output_sizes.x || m >= output_sizes.y) {
    return;
  }

  const int M = output_sizes.y;
  const int K4 = div_up_4(input_sizes.x);
  const int N4 = div_up_4(output_sizes.x);

  Int32Accum out_accum;
  initialize(out_accum);

  Int8InputTile int8_in_tile;
  Int8WeightTile int8_weight_tile;

  const int k4_per_iter = TILE_K4 * K_TILES_PER_WG;

  // No bounds checks are needed since packed input and weight are structured
  // in units of 4x4 blocks.
  for (int k4 = k_tile_lid; k4 < K4; k4 += k4_per_iter) {
    load_int8_input_tile(int8_in_tile, k4, m4, K4);
    load_int8_weight_tile(int8_weight_tile, n4, k4, N4);

    int_accumulate_with_int8_weight(out_accum, int8_in_tile, int8_weight_tile);
  }

  partial_sums[m_tile_lid][n_tile_lid][k_tile_lid] = out_accum;

  memoryBarrierShared();
  barrier();

  // Tree reduction over the workgroup's K axis to combine the partial sums
  // computed by each k_tile_lid.
  for (int i = K_TILES_PER_WG / 2; i > 0; i /= 2) {
    if (k_tile_lid < i) {
      add_into_first(
          partial_sums[m_tile_lid][n_tile_lid][k_tile_lid],
          partial_sums[m_tile_lid][n_tile_lid][k_tile_lid + i]);
    }
    memoryBarrierShared();
    barrier();
  }

  // Only the first thread along K writes out the reduced result.
  if (k_tile_lid > 0) {
    return;
  }

  out_accum = partial_sums[m_tile_lid][n_tile_lid][0];

  FPPerOutChannelParams weight_scales_tile;
  load_weight_scales_tile(weight_scales_tile, n4);

  IntPerOutChannelParams weight_sums_tile;
  load_weight_sums_tile(weight_sums_tile, n4);

  FPOutTile out_tile;
  initialize(out_tile);

  if (apply_bias > 0) {
    FPPerOutChannelParams bias_tile;
    load_bias_tile(bias_tile, n4);

    accumulate_out_tile_with_int_accum(
        out_tile,
        out_accum,
        input_scale,
        input_zp,
        weight_sums_tile,
        weight_scales_tile,
        bias_tile);
  } else {
    accumulate_out_tile_with_int_accum(
        out_tile,
        out_accum,
        input_scale,
        input_zp,
        weight_sums_tile,
        weight_scales_tile);
  }

  if (M - m >= TILE_M) {
    write_output_tile_no_checks(out_tile, n4, m, N4, M);
  } else {
    write_output_tile_with_checks(out_tile, n4, m, N4, M);
  }
}
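For reference, the epilogue above can fold the activation zero point out of the inner loop because sum_k (qx[k] - zp) * qw[k][n] = sum_k qx[k] * qw[k][n] - zp * sum_k qw[k][n]; only the raw int32 accumulator plus the precomputed per-channel weight sums (t_weight_sums) are needed at dequantization time. Below is a minimal scalar sketch of that math, assuming accumulate_out_tile_with_int_accum applies this standard correction; the function and variable names here are illustrative, not from the source.

// Scalar reference for one output element of the q8ta/q8csw linear; a sketch,
// not the shader's actual helper implementation.
#include <cstdint>
#include <vector>

float linear_q8ta_q8csw_ref_elem(
    const std::vector<int8_t>& qx,      // M x K quantized input (row-major)
    const std::vector<int8_t>& qw,      // K x N quantized weight (row-major)
    const std::vector<float>& w_scales, // per-output-channel scales, size N
    const std::vector<float>& bias,     // size N, used when apply_bias is true
    float input_scale, int32_t input_zp, bool apply_bias,
    int K, int N, int m, int n) {
  int32_t acc = 0;   // corresponds to one entry of the shader's Int32Accum
  int32_t w_sum = 0; // corresponds to t_weight_sums[n], precomputed on host
  for (int k = 0; k < K; ++k) {
    acc += int32_t(qx[m * K + k]) * int32_t(qw[k * N + n]);
    w_sum += int32_t(qw[k * N + n]);
  }
  // Zero-point correction, then per-tensor * per-channel dequantization.
  const float out = float(acc - input_zp * w_sum) * input_scale * w_scales[n];
  return apply_bias ? out + bias[n] : out;
}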
Lines changed: 25 additions & 0 deletions (new file)

@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

linear_q8ta_q8csw_tiled_smem:
  parameter_names_with_default_values:
    DTYPE: float
    OUTPUT_STORAGE: texture3d
    PACKED_INT8_INPUT_STORAGE: buffer
    WEIGHT_STORAGE: texture2d
    TILE_M4: 1
    TILE_N4: 2
    TILE_K4: 1
  generate_variant_forall:
    combination:
      parameter_names: [OUTPUT_STORAGE, PACKED_INT8_INPUT_STORAGE, WEIGHT_STORAGE]
      combos:
        - parameter_values: [texture3d, buffer, texture2d]
        - parameter_values: [buffer, buffer, texture2d]
    DTYPE:
      - VALUE: float
  shader_variants:
    - NAME: linear_q8ta_q8csw_tiled_smem
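The two combos above yield two compiled variants of the template, one per output-storage type. A small sketch of how the dispatch side is expected to resolve one of them; the exact suffix strings are an assumption inferred from the add_storage_type_suffix calls in the C++ change below, not taken from the source.

#include <string>

// Hypothetical illustration: base name plus one storage suffix per argument,
// matching the order of the add_storage_type_suffix calls below.
std::string kernel_name = "linear_q8ta_q8csw_tiled_smem";
kernel_name += "_texture3d"; // output storage ("_buffer" for the second combo)
kernel_name += "_buffer";    // packed int8 input storage
kernel_name += "_texture2d"; // weight storage
// -> "linear_q8ta_q8csw_tiled_smem_texture3d_buffer_texture2d"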

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 4 additions & 3 deletions

@@ -100,8 +100,9 @@ utils::uvec3 quantized_linear_local_wg_size(
   if (use_coop_algorithm) {
     return {1, 1, 64};
   } else {
-    return pick_hw_square_wg_size(
-        graph, shader, global_workgroup_size, args, resize_args);
+    // return pick_hw_square_wg_size(
+    //     graph, shader, global_workgroup_size, args, resize_args);
+    return {8, 8, 1};
   }
 }

@@ -595,7 +596,7 @@ DynamicDispatchNode make_linear_qa_qw_node(
   int32_t zp = graph.extract_scalar<int32_t>(input_zp_data);

   // Get shader for quantized linear
-  std::string kernel_name = "linear_q8ta_q8csw_tiled";
+  std::string kernel_name = "linear_q8ta_q8csw_tiled_smem";
   add_storage_type_suffix(kernel_name, graph.storage_type_of(output));
   add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_int_input));
   add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight));