Skip to content

Commit 45154bf

Browse files
trivedivivek authored and facebook-github-bot committed
Improving 4bit quant mat mul performance by shifting position of -8 operation.
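Summary: This diff improves the performance of the 4-bit quantized matrix multiplication by moving the -8 offset. Instead of subtracting 8 from every unpacked weight value, the shader now adds -8 times the sum of the input (mat1) values to the accumulators when mat1 is loaded. Because sum_k mat1[k] * (q[k] - 8) = sum_k mat1[k] * q[k] - 8 * sum_k mat1[k], the result is mathematically unchanged, while the overall number of math operations performed at shader runtime is reduced. Differential Revision: D85721578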
1 parent 30d7cae commit 45154bf

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ void main() {
8686
VEC4_T mat1_vec4 = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
8787
$for j in range(4):
8888
mat1[i * 4 + ${j}] = mat1_vec4[${j}];
89+
90+
$if QUANT_NBITS == 4:
91+
// Apply the -8 * mat1 bias here rather than below, to reduce the overall number of math operations performed at runtime.
92+
const T accum = mat1[i * 4 + 0] + mat1[i * 4 + 1] + mat1[i * 4 + 2] + mat1[i * 4 + 3];
93+
$for c in range(TILE_TXCOLS):
94+
$for j in range(4):
95+
sums[i * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += accum * -8.0;
8996
}
9097

9198
$if WEIGHT_STORAGE == "buffer":
@@ -109,13 +116,13 @@ void main() {
109116
packed_weight_tex = texelFetch(
110117
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
111118

112-
qmat2_vec4 = (VEC4_T(packed_weight_tex >> 4) - 8.0);
119+
qmat2_vec4 = VEC4_T(packed_weight_tex >> 4);
113120
qmat2[${c} * 4 * TILE_TXCOLS + 0] = qmat2_vec4.x;
114121
qmat2[${c} * 4 * TILE_TXCOLS + 1] = qmat2_vec4.y;
115122
qmat2[${c} * 4 * TILE_TXCOLS + 2] = qmat2_vec4.z;
116123
qmat2[${c} * 4 * TILE_TXCOLS + 3] = qmat2_vec4.w;
117124

118-
qmat2_vec4 = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
125+
qmat2_vec4 = VEC4_T(packed_weight_tex & 0x0F);
119126
qmat2[${c} * 4 * TILE_TXCOLS + 4] = qmat2_vec4.x;
120127
qmat2[${c} * 4 * TILE_TXCOLS + 5] = qmat2_vec4.y;
121128
qmat2[${c} * 4 * TILE_TXCOLS + 6] = qmat2_vec4.z;

0 commit comments

Comments
 (0)