From 90d99e2204e4ef405296299b7a955ec53a4b3dbb Mon Sep 17 00:00:00 2001 From: Vivek Trivedi Date: Sat, 18 Oct 2025 18:44:33 -0700 Subject: [PATCH] Doubling tile texel col count for mat mul op to improve performance. (#15192) Summary: ### Summary This change doubled tile texel column count for 8 bit matrix multiplication operation to improve performance. Reviewed By: SS-JIA Differential Revision: D84679398 --- .../vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl | 9 +++++++-- .../vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml | 2 +- .../runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl | 9 +++++++-- .../runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml | 2 +- .../runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp | 4 ++-- 5 files changed, 18 insertions(+), 8 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl index c766a3cd7d0..31e04c3a86a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl @@ -98,12 +98,17 @@ void main() { // Preload weight tensor [[unroll]] for (int r = 0; r < 4; r++) { $if QUANT_NBITS == 4: + $if WEIGHT_STORAGE == "buffer": + u8vec4 packed_weight_tex; + $else: + uvec4 packed_weight_tex; + $for c in range(0, TILE_TXCOLS, 2): $if WEIGHT_STORAGE == "buffer": qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; - const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}] + packed_weight_tex = t_weight[qmat2_bufi + ${c}] $else: - const uvec4 packed_weight_tex = texelFetch( + packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0); diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml index 3dff6855142..f05dc7104c4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml @@ -12,7 +12,7 @@ linear_qcsnw_coop: WEIGHT_STORAGE: texture2d SCALES_STORAGE: texture2d TILE_ROWS: 4 - TILE_TXCOLS: 1 + TILE_TXCOLS: 2 QUANT_NBITS: 8 generate_variant_forall: TILE_ROWS: diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index bc000580f76..936fd641a9b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -106,12 +106,17 @@ void main() { for (int r = 0; r < 4; r++) { VEC4_T qmat2[TILE_TXCOLS]; $if QUANT_NBITS == 4: + $if WEIGHT_STORAGE == "buffer": + u8vec4 packed_weight_tex; + $else: + uvec4 packed_weight_tex; + $for c in range(0, TILE_TXCOLS, 2): $if WEIGHT_STORAGE == "buffer": qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; - const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}] + packed_weight_tex = t_weight[qmat2_bufi + ${c}] $else: - const uvec4 packed_weight_tex = texelFetch( + packed_weight_tex = texelFetch( t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0); qmat2[${c}] = (VEC4_T(packed_weight_tex >> 4) - 8.0); diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml index 1c9ec4e524a..bbae284349a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml @@ -12,7 +12,7 @@ linear_qcsnw_tiled: WEIGHT_STORAGE: texture2d SCALES_STORAGE: texture2d TILE_ROWS: 4 - TILE_TXCOLS: 1 + TILE_TXCOLS: 2 QUANT_NBITS: 8 generate_variant_forall: TILE_ROWS: diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index 89c9e847724..971291cb11f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -73,7 +73,7 @@ utils::uvec3 linear_qcsnw_tiled_global_wg_size( } // Number of output texels in the output tile - uint32_t out_tile_ntxcols = 1; + uint32_t out_tile_ntxcols = 2; if (quant_nbits == 4) { out_tile_ntxcols = 2; } @@ -324,7 +324,7 @@ void add_linear_qcsnw_tiled_node( } // Number of output texels in the output tile - uint32_t out_tile_ntxcols = 1; + uint32_t out_tile_ntxcols = 2; if (quant_nbits == 4) { out_tile_ntxcols = 2; }