Commit 3f57577

trivedivivek authored and lucylq committed
Converting all uint16 to int in quantized mat mul shader to improve perf.
Differential Revision: D84777696
Pull Request resolved: #15193
1 parent 419727c commit 3f57577
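
The shader change follows one pattern throughout: uint16_t index arithmetic and u16vec* texel coordinates become plain 32-bit int and ivec* (the #extension GL_EXT_control_flow_attributes requirement is removed as well), presumably because 16-bit index math can require extra conversions on some GPUs. As a rough illustration of the resulting indexing style, here is a minimal, self-contained compute-shader sketch; the bindings, tensor names (t_in, t_out, out_sizes, in_sizes) and the divup4 helper are stand-ins for this example, not the real linear_qcsnw_tiled.glsl kernel.

#version 450
// Illustrative sketch only: shows the int/ivec indexing style adopted by this
// commit. Bindings, names and the divup4 helper are assumptions for the
// example, not the actual kernel.

layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

layout(set = 0, binding = 0) uniform sampler3D t_in;   // input tensor as a 3D texture
layout(set = 0, binding = 1, std430) buffer OutBuf { vec4 data[]; } t_out;
layout(push_constant) uniform Sizes {
  ivec4 out_sizes;
  ivec4 in_sizes;
};

int divup4(int x) { return (x + 3) / 4; }

void main() {
  // All index math in 32-bit int, casting the uint invocation ID once.
  const int global_wg_x = divup4(out_sizes.x);
  const int out_txcol = int(gl_GlobalInvocationID.x) % global_wg_x;
  const int out_row = int(gl_GlobalInvocationID.x) / global_wg_x;

  if (out_row >= out_sizes.y) {
    return;
  }

  vec4 acc = vec4(0.0);
  for (int pos = 0, txpos = 0; pos < in_sizes.x; pos += 4, ++txpos) {
    // Texel coordinates are ivec3 rather than u16vec3.
    acc += texelFetch(t_in, ivec3(txpos, out_row, 0), 0);
  }
  t_out.data[out_row * global_wg_x + out_txcol] = acc;
}

The real shader keeps its TILE_ROWS / TILE_TXCOLS tiling and its Python-template $if branches; the hunks below mostly just swap the index and coordinate types.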

File tree: 3 files changed (+27, -23 lines)


backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 14 additions & 18 deletions
@@ -21,8 +21,6 @@ ${define_required_extensions(DTYPE)}
 $if WEIGHT_STORAGE == "buffer":
   ${define_required_extensions("int8")}
 
-#extension GL_EXT_control_flow_attributes : require
-
 layout(std430) buffer;
 
 ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
@@ -49,20 +47,18 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 void main() {
   // txcol stands for "texel column". One txcol corresponds to 4 scalar columns.
   $if TILE_TXCOLS > 1:
-    const uint16_t global_wg_x = uint16_t(divup(out_sizes.x, 4 * TILE_TXCOLS));
-    const uint16_t out_txcol = uint16_t(
-        (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS);
+    const int global_wg_x = divup(out_sizes.x, 4 * TILE_TXCOLS);
+    const int out_txcol = (int(gl_GlobalInvocationID.x) % global_wg_x) * TILE_TXCOLS;
   $else:
-    const uint16_t global_wg_x = uint16_t(divup4(out_sizes.x));
-    const uint16_t out_txcol = uint16_t(gl_GlobalInvocationID.x % global_wg_x);
+    const int global_wg_x = divup4(out_sizes.x);
+    const int out_txcol = int(gl_GlobalInvocationID.x) % global_wg_x;
 
-  const uint16_t out_row = uint16_t(
-      (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS);
+  const int out_row = (int(gl_GlobalInvocationID.x) / global_wg_x) * TILE_ROWS;
 
   $if QUANT_NBITS == 4:
-    const uint16_t weight_txcol = uint16_t(out_txcol / 2);
+    const int weight_txcol = out_txcol / 2;
 
-  if (out_row >= uint16_t(out_sizes.y)) {
+  if (out_row >= int(out_sizes.y)) {
     return;
   }
 
@@ -73,9 +69,9 @@ void main() {
       sums[r][${c}] = VEC4_T(0.0);
   }
 
-  for (uint16_t pos = uint16_t(0), txpos = uint16_t(0);
-       pos < uint16_t(in_sizes.x);
-       pos += uint16_t(4), txpos += uint16_t(1)) {
+  for (int pos = 0, txpos = 0;
+       pos < in_sizes.x;
+       pos += 4, txpos += 1) {
 
     T mat1[TILE_ROWS][4];
 
@@ -91,7 +87,7 @@ void main() {
       mat1[i][2] = tmp.z;
       mat1[i][3] = tmp.w;
     $else:
-      VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0));
+      VEC4_T tmp = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
       mat1[i][0] = tmp.x;
       mat1[i][1] = tmp.y;
       mat1[i][2] = tmp.z;
@@ -117,7 +113,7 @@ void main() {
         packed_weight_tex = t_weight[qmat2_bufi + ${c}]
       $else:
         packed_weight_tex = texelFetch(
-            t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0);
+            t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
 
       qmat2[${c}] = (VEC4_T(packed_weight_tex >> 4) - 8.0);
       qmat2[${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
@@ -128,7 +124,7 @@ void main() {
        qmat2[${c}] = t_weight[qmat2_bufi + ${c}];
       $else:
         qmat2[${c}] = VEC4_T(
-            texelFetch(t_weight, u16vec2(out_txcol + ${c}, pos + r), 0));
+            texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0));
 
     for (int tr = 0; tr < TILE_ROWS; ++tr) {
       $for c in range(TILE_TXCOLS):
@@ -143,7 +139,7 @@ void main() {
     scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]);
   $else:
     scales[${c}] = VEC4_T(
-        texelFetch(t_scales, u16vec2(out_txcol + ${c}, 0), 0));
+        texelFetch(t_scales, ivec2(out_txcol + ${c}, 0), 0));
 
   // Store to output tensor
   $if OUT_STORAGE == "buffer":

extension/llm/runner/text_llm_runner.cpp

Lines changed: 8 additions & 4 deletions
@@ -98,7 +98,10 @@ Error TextLLMRunner::generate(
   std::function<void(const std::string&)> wrapped_callback =
       [token_callback, config](const std::string& piece) {
         if (!config.warming) {
+          // llm::safe_printf("\033[32m");
           llm::safe_printf(piece.c_str());
+          // llm::safe_printf("\033[0m\n");
+          // \033[32mThis text is green.\033[0m\n
           fflush(stdout);
         }
         if (token_callback) {
@@ -169,6 +172,11 @@ Error TextLLMRunner::generate(
   stats_->first_token_ms = time_in_ms();
   stats_->prompt_eval_end_ms = time_in_ms();
 
+  RUNNER_ET_LOG(
+      config.warming,
+      "RSS after prompt prefill: %f MiB (0 if unsupported)",
+      get_rss_bytes() / 1024.0 / 1024.0);
+
   // print the first token from prefill. No prev_token so use cur_token for it.
   auto decode_result = tokenizer_->decode(cur_token, cur_token);
   if (!decode_result.ok()) {
@@ -179,10 +187,6 @@ Error TextLLMRunner::generate(
     return ::executorch::runtime::Error::InvalidArgument;
   }
   wrapped_callback(std::move(*decode_result));
-  RUNNER_ET_LOG(
-      config.warming,
-      "RSS after prompt prefill: %f MiB (0 if unsupported)",
-      get_rss_bytes() / 1024.0 / 1024.0);
 
   // start the main loop
   prompt_tokens.push_back(cur_token);

extension/llm/runner/text_token_generator.h

Lines changed: 5 additions & 1 deletion
@@ -128,9 +128,13 @@ class ET_EXPERIMENTAL TextTokenGenerator {
       if (eos_ids_->find(cur_token) != eos_ids_->end()) {
         printf("\n");
         ET_LOG(Info, "\nReached to the end of generation");
-        break;
+        return pos - start_pos;
       }
     }
+    ET_LOG(
+        Info,
+        "\nFinished generation. Generated %" PRIi32 " tokens.",
+        start_pos + max_new_tokens);
     return pos - start_pos;
   }
 

0 commit comments
