Skip to content

Commit a41a618

Browse files
committed
[ET-VK][ez] Fix 8 bit linear compute shader dispatch
Pull Request resolved: #9531 ## Context Currently, for the `q_8w_linear` shader, both the texture and the buffer variants use the same global work group and local work group setting. Specially, the global work group is set to `{out.numel(), 1, 1}` and the local work group is set to `{64, 1, 1}`. However, I believe this results in a very poor memory re-use for the texture shader. In this configuration: * Within a work group each invocation will be requesting a different row of A - 64 rows of A requested in total * All work groups will be requesting the same row of B * One work group will load 65 unique rows from A and B Compare this to a local work group size of `{8, 8, 1}` * Across the work group, 8 rows will be loaded from A and 8 rows will be loaded from B * One work group will load 16 unique rows total from A and B Evidently, there is better memory re-use in the latter work group as fewer unique rows are loaded. ## Changes Modify the `q_8w_linear` shader to use `{8, 8, 1}` local wg if possible. If `M` is small, then instead use `{4, 16, 1}` or `{2, 32, 1}` to reduce the number of inactive invocations. ghstack-source-id: 274198011 @exported-using-ghexport Differential Revision: [D71706489](https://our.internmc.facebook.com/intern/diff/D71706489/)
1 parent 7159650 commit a41a618

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,10 @@ void main() {
9090

9191
void main() {
9292
const u16vec2 out_pos = u16vec2(
93-
gl_GlobalInvocationID.x / out_limits.y,
94-
gl_GlobalInvocationID.x % out_limits.y);
95-
if (out_pos.x >= out_limits.x) {
93+
gl_GlobalInvocationID.x,
94+
gl_GlobalInvocationID.y);
95+
96+
if (out_pos.x >= out_limits.x || out_pos.y >= out_limits.y) {
9697
return;
9798
}
9899

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,37 @@ void add_q_8w_linear_node(
114114
graph.sizes_ubo(mat1_W_packed)});
115115
}
116116

117-
// set global work group size to be 1 dimensional
118-
const utils::uvec3 wg_size = {
119-
static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};
117+
utils::uvec3 global_wg;
118+
if (graph.is_buffer_storage(out)) {
119+
global_wg = {static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};
120+
} else {
121+
global_wg = graph.logical_limits_of(out_W_packed);
122+
}
123+
124+
utils::uvec3 local_wg{8, 8, 1};
125+
int32_t out_W = graph.size_at<int32_t>(-1, out_W_packed);
126+
127+
if (graph.is_buffer_storage(out_W_packed)) {
128+
local_wg[0] = 64;
129+
local_wg[1] = 1;
130+
local_wg[2] = 1;
131+
} else {
132+
if (out_W % 8 != 0) {
133+
if (out_W % 4 == 0) {
134+
local_wg[0] = 4;
135+
local_wg[1] = 16;
136+
} else {
137+
local_wg[0] = 2;
138+
local_wg[1] = 32;
139+
}
140+
}
141+
}
120142

121143
graph.execute_nodes().emplace_back(new DispatchNode(
122144
graph,
123145
VK_KERNEL_FROM_STR(kernel_name),
124-
wg_size,
125-
graph.create_local_wg_size(wg_size),
146+
global_wg,
147+
local_wg,
126148
// Inputs and Outputs
127149
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
128150
{{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},

0 commit comments

Comments
 (0)