Skip to content

Commit e918ec2

Browse files
authored
[ET-VK][ez] Fix 8 bit linear compute shader dispatch
Differential Revision: D71706489 Pull Request resolved: #9531
1 parent 1e5c0d4 commit e918ec2

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,10 @@ void main() {
9090

9191
void main() {
9292
const u16vec2 out_pos = u16vec2(
93-
gl_GlobalInvocationID.x / out_limits.y,
94-
gl_GlobalInvocationID.x % out_limits.y);
95-
if (out_pos.x >= out_limits.x) {
93+
gl_GlobalInvocationID.x,
94+
gl_GlobalInvocationID.y);
95+
96+
if (out_pos.x >= out_limits.x || out_pos.y >= out_limits.y) {
9697
return;
9798
}
9899

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,37 @@ void add_q_8w_linear_node(
114114
graph.sizes_ubo(mat1_W_packed)});
115115
}
116116

117-
// set global work group size to be 1 dimensional
118-
const utils::uvec3 wg_size = {
119-
static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};
117+
utils::uvec3 global_wg;
118+
if (graph.is_buffer_storage(out)) {
119+
global_wg = {static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};
120+
} else {
121+
global_wg = graph.logical_limits_of(out_W_packed);
122+
}
123+
124+
utils::uvec3 local_wg{8, 8, 1};
125+
int32_t out_W = graph.size_at<int32_t>(-1, out_W_packed);
126+
127+
if (graph.is_buffer_storage(out_W_packed)) {
128+
local_wg[0] = 64;
129+
local_wg[1] = 1;
130+
local_wg[2] = 1;
131+
} else {
132+
if (out_W % 8 != 0) {
133+
if (out_W % 4 == 0) {
134+
local_wg[0] = 4;
135+
local_wg[1] = 16;
136+
} else {
137+
local_wg[0] = 2;
138+
local_wg[1] = 32;
139+
}
140+
}
141+
}
120142

121143
graph.execute_nodes().emplace_back(new DispatchNode(
122144
graph,
123145
VK_KERNEL_FROM_STR(kernel_name),
124-
wg_size,
125-
graph.create_local_wg_size(wg_size),
146+
global_wg,
147+
local_wg,
126148
// Inputs and Outputs
127149
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
128150
{{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},

0 commit comments

Comments
 (0)