Commit 54f5ffe

pytorchbot and ssjia authored
[ET-VK] Improve q8 matmul by increasing TILE_N4 (#14610)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #14597 by @SS-JIA ^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/331/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/331/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/329/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/331/orig
Differential Revision: [D83253129](https://our.internmc.facebook.com/intern/diff/D83253129/)
@diff-train-skip-merge
Co-authored-by: ssjia <[email protected]>
1 parent 681680e commit 54f5ffe

File tree

3 files changed: +6 −2 lines changed


backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ void accumulate_out_tile_with_int_accum(
         input_zp_vec * weight_sums.data[n4] + accum.data[m][n4];
     out_tile.data[m][n4] =
-        fma(VEC4_T(accum_adjusted),
-            VEC4_T(input_q_scale * weight_scales.data[0]),
+        fma(VEC4_T(accum_adjusted),
+            VEC4_T(input_q_scale * weight_scales.data[n4]),
             out_tile.data[m][n4]);
   }
 }

backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ linear_q8ta_q8csw_tiled:
     PACKED_INT8_INPUT_STORAGE: buffer
     WEIGHT_STORAGE: texture2d
     TILE_M4: 1
-    TILE_N4: 1
+    TILE_N4: 2
     TILE_K4: 1
   generate_variant_forall:
     DTYPE:

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 4 additions & 0 deletions
@@ -77,6 +77,10 @@ utils::uvec3 quantized_linear_global_wg_size(
     M_per_tile = 1;
   }
 
+  if (shader.kernel_name.find("q8ta_q8csw_tiled") != std::string::npos) {
+    N_per_tile = 8;
+  }
+
   const uint32_t num_N_tiles = utils::div_up(N, N_per_tile);
   const uint32_t num_M_tiles = utils::div_up(M, M_per_tile);
