Skip to content

Commit f8e199c

Browse files
authored
[ET-VK] Removing manual unroll in linear shader to improve overall performance.
Differential Revision: D84571616 Pull Request resolved: #15110
1 parent e1d9fd2 commit f8e199c

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void main() {
6666
return;
6767
}
6868

69-
VEC4_T mat1[TILE_ROWS];
69+
T mat1[TILE_ROWS][4];
7070
VEC4_T qmat2[4][TILE_TXCOLS];
7171
VEC4_T sums[TILE_ROWS][TILE_TXCOLS];
7272

@@ -78,7 +78,7 @@ void main() {
7878
scales[${c}] = VEC4_T(
7979
texelFetch(t_scales, u16vec2(out_txcol + ${c}, 0), 0));
8080

81-
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
81+
for (int r = 0; r < TILE_ROWS; ++r) {
8282
$for c in range(TILE_TXCOLS):
8383
sums[r][${c}] = VEC4_T(0.0);
8484
}
@@ -91,7 +91,7 @@ void main() {
9191
uint weight_row_txstride = div4(weight_sizes.x);
9292

9393
// Preload weight tensor
94-
[[unroll]] for (int r = 0; r < 4; r++) {
94+
for (int r = 0; r < 4; r++) {
9595
$if QUANT_NBITS == 4:
9696
$for c in range(0, TILE_TXCOLS, 2):
9797
$if WEIGHT_STORAGE == "buffer":
@@ -117,21 +117,28 @@ void main() {
117117
uint in_row_txstride = div4(in_sizes.x);
118118

119119
// Preload input tensor
120-
[[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
120+
for (int i = 0; i < TILE_ROWS; i++) {
121121
$if IN_STORAGE == "buffer":
122-
mat1[i] = t_in[(out_row + i) * in_row_txstride + txpos];
122+
VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos];
123+
mat1[i][0] = tmp.x;
124+
mat1[i][1] = tmp.y;
125+
mat1[i][2] = tmp.z;
126+
mat1[i][3] = tmp.w;
123127
$else:
124-
mat1[i] = VEC4_T(
125-
texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0));
128+
VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0));
129+
mat1[i][0] = tmp.x;
130+
mat1[i][1] = tmp.y;
131+
mat1[i][2] = tmp.z;
132+
mat1[i][3] = tmp.w;
126133
}
127134

128135
// Accumulate output
129-
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
136+
for (int r = 0; r < TILE_ROWS; ++r) {
130137
$for c in range(TILE_TXCOLS):
131-
sums[r][${c}] += mat1[r].x * qmat2[0][${c}] +
132-
mat1[r].y * qmat2[1][${c}] +
133-
mat1[r].z * qmat2[2][${c}] +
134-
mat1[r].w * qmat2[3][${c}];
138+
sums[r][${c}] += mat1[r][0] * qmat2[0][${c}] +
139+
mat1[r][1] * qmat2[1][${c}] +
140+
mat1[r][2] * qmat2[2][${c}] +
141+
mat1[r][3] * qmat2[3][${c}];
135142
}
136143
}
137144

@@ -140,7 +147,7 @@ void main() {
140147
uint out_bufi;
141148
uint out_row_txstride = div4(out_sizes.x);
142149

143-
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
150+
for (int r = 0; r < TILE_ROWS; ++r) {
144151
$for c in range(TILE_TXCOLS):
145152
$if OUT_STORAGE == "buffer":
146153
if (out_row + r < out_sizes.y) {

0 commit comments

Comments
 (0)