Skip to content

Commit 219e746

Browse files
authored
[ET-VK] Efficient tiled int8 matmul
Differential Revision: D72066587 Pull Request resolved: #9766
1 parent d73f38f commit 219e746

File tree

5 files changed

+184
-310
lines changed

5 files changed

+184
-310
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl

Lines changed: 0 additions & 212 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml

Lines changed: 0 additions & 35 deletions
This file was deleted.
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define T ${buffer_scalar_type(DTYPE)}
14+
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
15+
16+
#define TILE_ROWS ${TILE_ROWS}
17+
18+
${define_required_extensions(DTYPE)}
19+
20+
$if STORAGE == "buffer":
21+
${define_required_extensions("int8")}
22+
23+
#extension GL_EXT_control_flow_attributes : require
24+
25+
layout(std430) buffer;
26+
27+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
28+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
29+
${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)}
30+
${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)}
31+
32+
33+
// Tensor extents delivered as push constants.
//  - out_sizes: .x and .y bound the output column and row handled by main().
//  - in_sizes: .x is the length of the reduction loop in main() and the row
//    stride used when the input is read from a buffer.
//  - weight_sizes: presumably the extents of the int8 weight tensor -- confirm
//    against the host-side dispatch code.
layout(push_constant) uniform restrict Block {
34+
ivec4 out_sizes;
35+
ivec4 in_sizes;
36+
ivec4 weight_sizes;
37+
};
38+
39+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
40+
41+
/*
 * Tiled matrix multiply of t_in against the int8 weights in t_weight, followed
 * by a per-output-column multiply with t_scales.
 *
 * Each invocation produces a tile of TILE_ROWS output rows by 4 output columns
 * (one 4-component texel per row):
 *   - gl_GlobalInvocationID.x selects the group of 4 output columns,
 *   - gl_GlobalInvocationID.y selects the group of TILE_ROWS output rows.
 *
 * The reduction dimension (in_sizes.x) is consumed 4 elements per iteration:
 * 4 weight texels and TILE_ROWS input texels are loaded, then accumulated into
 * the TILE_ROWS partial sums.
 *
 * NOTE(review): the reduction loop steps by 4 with no remainder handling, so
 * it appears to rely on in_sizes.x being padded to a multiple of 4 -- confirm.
 */
void main() {
  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
  const uint out_col = gl_GlobalInvocationID.x << 2;

  // Invocations whose tile starts outside the output have nothing to do.
  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
    return;
  }

  VEC4_T a[TILE_ROWS]; // input texels, one per tile row
  VEC4_T b[4];         // weight texels, one per reduction step of this chunk
  VEC4_T c[TILE_ROWS]; // accumulators, one per tile row

  // Scales are indexed by output column only, so one load serves every row.
  $if STORAGE == "buffer":
    const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
  $else:
    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));

  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
    c[i] = VEC4_T(0.0);
  }

  for (int pos = 0; pos < in_sizes.x; pos += 4) {
    // Preload weight tensor
    [[unroll]] for (int i = 0; i < 4; i++) {
      $if STORAGE == "buffer":
        // The row stride was previously taken from an undeclared identifier
        // (B_sizes); use the declared weight extents instead.
        // NOTE(review): assumes weight_sizes.x is the row width of the weight
        // buffer, mirroring the (column, row) addressing of the texture path
        // -- confirm against the host-side packing code.
        // The explicit VEC4_T conversion matches the texture path.
        b[i] = VEC4_T(t_weight[((pos + i) * weight_sizes.x + out_col) >> 2]);
      $else:
        b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0));
    }

    // Preload input tensor
    // NOTE(review): tail rows (out_row + i >= out_sizes.y) are still loaded
    // here; their results are never stored in the buffer path below.
    [[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
      $if STORAGE == "buffer":
        a[i] = t_in[((out_row + i) * in_sizes.x + (pos)) >> 2];
      $else:
        a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
    }

    // Compute partial output: each input component scales the weight texel of
    // the matching reduction step.
    [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
      c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3];
    }
  }

  // Store output tensor
  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
    $if STORAGE == "buffer":
      // Only out_row itself was range-checked above; skip tail rows so a
      // partially filled tile does not write past the last output row.
      if (out_row + i < out_sizes.y) {
        t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
      }
    $else:
      imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
  }
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Codegen settings for the q_8w_linear_tiled shader template: default template
# parameters first, then the list of shader variants to generate.
# Both variants listed here use texture3d storage and differ only in TILE_ROWS
# (4 and 6).
q_8w_linear_tiled:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: texture3d
11+
TILE_ROWS: 4
12+
shader_variants:
13+
- NAME: q_8w_linear_tiled_o4x4_texture3d_float
14+
STORAGE: texture3d
15+
TILE_ROWS: 4
16+
- NAME: q_8w_linear_tiled_o4x6_texture3d_float
17+
STORAGE: texture3d
18+
TILE_ROWS: 6

0 commit comments

Comments
 (0)