Skip to content

Commit a54693c

Browse files
trivedivivek authored and facebook-github-bot committed
Improving 4bit quant mat mul performance by shifting position of -8 operation. (#15436)
Summary: This diff introduces performance improvements in the 4-bit quantized matrix multiplication operation by moving where the -8 offset is applied, resulting in an overall reduction in the math operations performed during shader runtime. The thinking here is as follows: * The 4-bit integer weights are stored unsigned, ranging from 0 to 15, and thus 8 is subtracted from each weight to obtain a signed number. * Assume WS[] is the array of stored (unsigned) weights, M[] is the matrix, and S is the sum. The main loop essentially performs: S += ( WS[i] - 8 ) * M[i], for i = [0, N) * This equation can be rewritten as: S += WS[i] * M[i] - 8 * M[i], for i = [0, N) * The 8 * M[i] term need not be computed in the main loop. Instead, the sum of 8 * M[i], for i = [0, N), can be obtained by accumulating A += M[i], for i = [0, N), and then computing A *= 8 once; that product is subtracted from S after the loop. Thus, splitting the equation into these parts results in a significant reduction in math ops while producing the same result. Differential Revision: D85721578
1 parent 3485495 commit a54693c

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ void main() {
6464

6565
T sums[TILE_ROWS * TILE_TXCOLS * 4];
6666

67+
$if QUANT_NBITS == 4:
68+
// accumulate mat1 elements sum so -8 bias can be applied using it later.
69+
T mat1_accum[TILE_ROWS];
70+
$for r in range(TILE_ROWS):
71+
mat1_accum[${r}] = T(0.0);
72+
6773
for (int r = 0; r < TILE_ROWS; ++r) {
6874
$for c in range(TILE_TXCOLS):
6975
$for j in range(4):
@@ -86,6 +92,10 @@ void main() {
8692
VEC4_T mat1_vec4 = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
8793
$for j in range(4):
8894
mat1[i * 4 + ${j}] = mat1_vec4[${j}];
95+
96+
$if QUANT_NBITS == 4:
97+
// Accumulate mat1 element sum, this will be multiplied with -8 later for converting 4 bit data to a signed number.
98+
mat1_accum[i] += mat1[i * 4 + 0] + mat1[i * 4 + 1] + mat1[i * 4 + 2] + mat1[i * 4 + 3];
8999
}
90100

91101
$if WEIGHT_STORAGE == "buffer":
@@ -109,13 +119,13 @@ void main() {
109119
packed_weight_tex = texelFetch(
110120
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
111121

112-
qmat2_vec4 = (VEC4_T(packed_weight_tex >> 4) - 8.0);
122+
qmat2_vec4 = VEC4_T(packed_weight_tex >> 4);
113123
qmat2[${c} * 4 * TILE_TXCOLS + 0] = qmat2_vec4.x;
114124
qmat2[${c} * 4 * TILE_TXCOLS + 1] = qmat2_vec4.y;
115125
qmat2[${c} * 4 * TILE_TXCOLS + 2] = qmat2_vec4.z;
116126
qmat2[${c} * 4 * TILE_TXCOLS + 3] = qmat2_vec4.w;
117127

118-
qmat2_vec4 = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
128+
qmat2_vec4 = VEC4_T(packed_weight_tex & 0x0F);
119129
qmat2[${c} * 4 * TILE_TXCOLS + 4] = qmat2_vec4.x;
120130
qmat2[${c} * 4 * TILE_TXCOLS + 5] = qmat2_vec4.y;
121131
qmat2[${c} * 4 * TILE_TXCOLS + 6] = qmat2_vec4.z;
@@ -156,10 +166,16 @@ void main() {
156166
for (int r = 0; r < TILE_ROWS; ++r) {
157167
VEC4_T scaled_sums;
158168
$for c in range(TILE_TXCOLS):
159-
scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] * scales[${c}].x;
160-
scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] * scales[${c}].y;
161-
scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] * scales[${c}].z;
162-
scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] * scales[${c}].w;
169+
$if QUANT_NBITS == 4:
170+
scaled_sums.x = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] + mat1_accum[r] * -8.0) * scales[${c}].x;
171+
scaled_sums.y = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] + mat1_accum[r] * -8.0) * scales[${c}].y;
172+
scaled_sums.z = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] + mat1_accum[r] * -8.0) * scales[${c}].z;
173+
scaled_sums.w = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] + mat1_accum[r] * -8.0) * scales[${c}].w;
174+
$else:
175+
scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] * scales[${c}].x;
176+
scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] * scales[${c}].y;
177+
scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] * scales[${c}].z;
178+
scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] * scales[${c}].w;
163179

164180
$if OUT_STORAGE == "buffer":
165181
if (out_row + r < out_sizes.y) {

0 commit comments

Comments
 (0)