Skip to content

Commit 1d37845

Browse files
authored
[ET-VK] Misc code cleanup to recent Quantized Linear + SDPA implementations (#14197)
As titled. Introduces minor fixes and code cleanup to the recently added dqlinear and SDPA implementations. Differential Revision: [D82120825](https://our.internmc.facebook.com/intern/diff/D82120825/)
1 parent fe9447e commit 1d37845

11 files changed: +26 −44 lines

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ layout(push_constant) uniform PushConstants {
4040
int quant_max;
4141
};
4242

43-
#extension GL_EXT_debug_printf : enable
44-
4543
// Shared memory for cooperative min/max finding
4644
shared T shared_min[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
4745
shared T shared_max[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];

backends/vulkan/runtime/graph/ops/glsl/im2col.glsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
#version 450 core
1010

11-
#extension GL_EXT_debug_printf : enable
12-
1311
#define PRECISION ${PRECISION}
1412
#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)}
1513
#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)}

backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ void main() {
110110
IntPerInChannelParams int8_input_sums_tile;
111111

112112
const int num_groups = K4 / K4_per_group;
113+
const int group_size = mul_4(K4_per_group);
113114

114115
for (int group_i = 0; group_i < num_groups; ++group_i) {
115116
// Reset int accumulator
@@ -119,7 +120,6 @@ void main() {
119120

120121
load_int8_input_tile(int8_in_tile, k4, m4, K4);
121122
load_int4_weight_tile(int4_weight_tile, k4, n8, K4);
122-
// load_int4_weight_tile(int4_weight_tile, n8, k4, N8);
123123

124124
int_accumulate_with_int4_weight(
125125
out_accum, int8_in_tile, int4_weight_tile);
@@ -129,13 +129,6 @@ void main() {
129129
load_weight_sums_tile_for_group(weight_sums_tile, n4, group_i, N4);
130130
load_int8_input_sums_tile_for_group(int8_input_sums_tile, m4, group_i, M4);
131131

132-
const int group_size = mul_4(K4_per_group);
133-
134-
// // Update output tile with accumulated values
135-
// accumulate_out_tile_with_int_accum_from_int4_weights_test(
136-
// out_tile,
137-
// out_accum);
138-
139132
accumulate_out_tile_with_int_accum_from_int4_weights(
140133
out_tile,
141134
out_accum,

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,12 @@ void accumulate_out_tile_with_int_accum_from_int4_weights(
6363
const FPPerOutChannelParams weight_scales,
6464
const int group_size) {
6565
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
66-
float input_scale_m = input_scales.data[0][m];
67-
int input_zp_m = input_zps.data[0][m];
68-
int input_sum_m = input_sums.data[0][m];
66+
const int m4 = div_4(m);
67+
const int m4i = mod_4(m);
68+
69+
float input_scale_m = input_scales.data[m4][m4i];
70+
int input_zp_m = input_zps.data[m4][m4i];
71+
int input_sum_m = input_sums.data[m4][m4i];
6972

7073
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
7174
ivec4 accum_adjusted = accum.data[m][n4] -

backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ ${layout_declare_spec_const(C, "int", "K4_per_group", "0")}
7171
shared FPOutTile partial_sums[WGS];
7272

7373
void main() {
74-
const int lid = int(gl_LocalInvocationID.x);
75-
const int n8 = int(gl_GlobalInvocationID.y);
74+
const int lid = int(gl_LocalInvocationID.z);
75+
const int n8 = int(gl_GlobalInvocationID.x);
7676

7777
// The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes
7878
// 8 output elements, so each thread will write to 8 elements starting at the

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ shared FPOutTile partial_sums[NUM_WORKERS_PER_OUT];
6060
* the entire work group co-operates to compute one reduction output.
6161
*/
6262

63-
#extension GL_EXT_debug_printf : enable
64-
6563
void main() {
6664
const int worker_id = int(gl_LocalInvocationID.y);
6765

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,6 @@ ${layout_declare_spec_const(C, "float", "inv_scale", "1.0")}
7373
*
7474
*/
7575

76-
#extension GL_EXT_debug_printf : enable
77-
7876
void main() {
7977
const int tile_idx_x = int(gl_GlobalInvocationID.x);
8078
const int tile_idx_y = int(gl_GlobalInvocationID.y);

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ shared FPOutTile partial_sums[NUM_WORKERS_PER_OUT];
6060
* the entire work group co-operates to compute one reduction output.
6161
*/
6262

63-
#extension GL_EXT_debug_printf : enable
64-
6563
void main() {
6664
const int worker_id = int(gl_LocalInvocationID.y);
6765

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
5656
* output has shape (batches, seq_len, num_q_heads, head_dim)
5757
*/
5858

59-
#extension GL_EXT_debug_printf : enable
60-
6159
void main() {
6260
const int tile_idx_x = int(gl_GlobalInvocationID.x);
6361
const int tile_idx_y = int(gl_GlobalInvocationID.y);

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,29 +61,27 @@ utils::uvec3 quantized_linear_global_wg_size(
6161
const ValueRef out = args.at(0).refs.at(0);
6262

6363
std::vector<int64_t> out_sizes = graph->sizes_of(out);
64-
// height
65-
const uint32_t M = utils::val_at(-2, out_sizes);
6664
// width
6765
const uint32_t N = utils::val_at(-1, out_sizes);
66+
// height
67+
const uint32_t M = utils::val_at(-2, out_sizes);
6868

69-
const uint32_t M4 = utils::div_up(M, 4u);
70-
const uint32_t N4 = utils::div_up(N, 4u);
69+
uint32_t N_per_tile = 4;
70+
uint32_t M_per_tile = 4;
7171

72-
// For 4-bit weights, each output tile contains 8 columns and 4 rows
72+
// For 4-bit weights, each output tile contains 8 columns
7373
if (shader.kernel_name.find("q4") != std::string::npos) {
74-
const uint32_t N8 = utils::div_up(N, 8u);
75-
76-
const bool using_coop_algorithm =
77-
shader.kernel_name.find("_coop") != std::string::npos;
78-
// TODO: explain
79-
if (using_coop_algorithm) {
80-
return {64, N8, M};
81-
}
82-
return {N8, M4, 1};
74+
N_per_tile = 8;
75+
}
76+
if (shader.kernel_name.find("coop") != std::string::npos) {
77+
M_per_tile = 1;
8378
}
8479

80+
const uint32_t num_N_tiles = utils::div_up(N, N_per_tile);
81+
const uint32_t num_M_tiles = utils::div_up(M, M_per_tile);
82+
8583
// Otherwise, each output tile contains 4 columns and 4 rows
86-
return {N4, M4, 1};
84+
return {num_N_tiles, num_M_tiles, 1};
8785
}
8886

8987
utils::uvec3 quantized_linear_local_wg_size(
@@ -96,7 +94,7 @@ utils::uvec3 quantized_linear_local_wg_size(
9694
shader.kernel_name.find("_coop") != std::string::npos;
9795

9896
if (use_coop_algorithm) {
99-
return {64, 1, 1};
97+
return {1, 1, 64};
10098
} else {
10199
return pick_hw_square_wg_size(
102100
graph, shader, global_workgroup_size, args, resize_args);

0 commit comments

Comments (0)