Skip to content

Commit f40f187

Browse files
ahmtox authored and facebook-github-bot committed
Add SCALE_DTYPE and ZP_DTYPE support for quantization shaders
Summary: This change adds support for parameterized SCALE_DTYPE and ZP_DTYPE to the quantization and dequantization shaders. This is necessary because, when exporting llama with "8da4w", different affine calls may use different scale and zero point dtypes. It also adds functionality to automatically populate optional parameters.

NOTE: The fusion for linear_qta8a_qga4w is disabled for now, while the bug that prevents it from working when exporting llama is being resolved.

**Key Changes:**

(1) **YAML Configuration Updates:**
- Added SCALE_DTYPE and ZP_DTYPE parameters to quantize_texture.yaml and dequantize_texture.yaml
- Added generate_variant_forall entries for SCALE_DTYPE (float) and ZP_DTYPE (int8, int32, float)
- This enables shader variants for different scale and zero_point data types

(2) **GLSL Shader Updates:**
- Added SCALE_T and ZP_T type definitions using the new parameters
- Updated tensor declarations to use parameterized types instead of hardcoded "float" and "int"
- Added proper type casting (float() and int()) for all scale and zero_point accesses
- Added required extensions for SCALE_DTYPE and ZP_DTYPE

(3) **C++ Implementation Updates:**
- Added dtype suffixes for scale and zero_point in all quantize/dequantize node functions
- Added comprehensive data type validation in all implementation functions:
  - Scale tensors: fp32 only (for now)
  - Zero point tensors: int32, int8, fp32
- Updated Quantize.cpp, Dequantize.cpp, and ChooseQParams.cpp with consistent validation

This change resolves shader compilation errors and enables more flexible quantization strategies by supporting multiple data types for quantization parameters.

Differential Revision: D79835267
1 parent f7ddbde commit f40f187

18 files changed

+647
-172
lines changed

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -502,9 +502,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
502502
qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node)
503503
if qta8a_qga4w_details is not None:
504504
group_size, weight_bits = qta8a_qga4w_details
505-
fuse_into_linear_qta8a_qga4w_node(
506-
self.program, graph_module, node, group_size, weight_bits
507-
)
505+
# fuse_into_linear_qta8a_qga4w_node(
506+
# self.program, graph_module, node, group_size, weight_bits
507+
# )
508508
continue
509509

510510
graph_module.recompile()

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,22 @@
1111
#define PRECISION ${PRECISION}
1212

1313
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
14+
#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)}
15+
#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)}
1416

1517
#define ${MODE}
1618

1719
${define_active_storage_type("buffer")}
1820
${define_required_extensions(IN_DTYPE)}
21+
${define_required_extensions(SCALE_OUT_DTYPE)}
22+
${define_required_extensions(ZP_OUT_DTYPE)}
1923

2024
#extension GL_EXT_control_flow_attributes : require
2125

2226
layout(std430) buffer;
2327

24-
${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")}
25-
${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")}
28+
${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")}
29+
${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")}
2630
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
2731

2832
$if MODE == "per_tensor":
@@ -254,8 +258,8 @@ void choose_qparams_per_tensor() {
254258
// Use default values: mapping_type=0 (ASYMMETRIC), eps from push constant
255259
calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val);
256260

257-
t_scale[0] = scale_val;
258-
t_zero_point[0] = zero_point_val;
261+
t_scale[0] = SCALE_OUT_T(scale_val);
262+
t_zero_point[0] = ZP_OUT_T(zero_point_val);
259263
}
260264
}
261265

@@ -306,8 +310,8 @@ void choose_qparams_per_token() {
306310
calc_scale_zp(lo, hi, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val);
307311

308312
// Write results
309-
t_scale[token_id] = scale_val;
310-
t_zero_point[token_id] = zero_point_val;
313+
t_scale[token_id] = SCALE_OUT_T(scale_val);
314+
t_zero_point[token_id] = ZP_OUT_T(zero_point_val);
311315
}
312316
}
313317

@@ -380,12 +384,12 @@ void choose_qparams_block_wise() {
380384
hi = 0.0;
381385
}
382386

383-
float scale;
384-
int zp;
385-
calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp);
387+
float scale_val;
388+
int zero_point_val;
389+
calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale_val, zero_point_val);
386390

387-
t_zero_point[block_id] = zp;
388-
t_scale[block_id] = scale;
391+
t_scale[block_id] = SCALE_OUT_T(scale_val);
392+
t_zero_point[block_id] = ZP_OUT_T(zero_point_val);
389393
}
390394
}
391395

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
choose_qparams_buffer:
22
parameter_names_with_default_values:
33
IN_DTYPE: float
4+
SCALE_OUT_DTYPE: float
5+
ZP_OUT_DTYPE: int32
46
MODE: per_tensor
57
generate_variant_forall:
68
IN_DTYPE:
79
- VALUE: float
10+
SCALE_OUT_DTYPE:
11+
- VALUE: float
12+
ZP_OUT_DTYPE:
13+
- VALUE: int32
14+
- VALUE: int8
15+
- VALUE: float
816
shader_variants:
917
- NAME: choose_qparams_tensor_buffer
1018
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,26 @@
1212

1313
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
1414
#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")}
15+
#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)}
16+
#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)}
1517

1618
#define ${MODE}
1719

1820
${define_active_storage_type("texture3d")}
1921
${define_required_extensions(IN_DTYPE)}
22+
${define_required_extensions(SCALE_OUT_DTYPE)}
23+
${define_required_extensions(ZP_OUT_DTYPE)}
2024

2125
#extension GL_EXT_control_flow_attributes : require
2226

2327
layout(std430) buffer;
2428

2529
$if MODE != "block_wise":
26-
${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")}
27-
${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")}
30+
${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "texture3d")}
31+
${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "texture3d")}
2832
$else:
29-
${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")}
30-
${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")}
33+
${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")}
34+
${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")}
3135

3236
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
3337

@@ -273,8 +277,8 @@ void choose_qparams_per_tensor() {
273277
int zero_point_val;
274278
calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val);
275279

276-
write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0));
277-
write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0));
280+
write_texel(t_scale, ivec3(0, 0, 0), vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0));
281+
write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0));
278282
}
279283
}
280284

@@ -419,8 +423,8 @@ void choose_qparams_per_token() {
419423
uint out_x = out_remainder % uint(t_scale_limits.x);
420424
ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z));
421425

422-
write_texel(t_scale, out_pos, vec4(scale_val, 0.0, 0.0, 0.0));
423-
write_texel(t_zero_point, out_pos, ivec4(zero_point_val, 0, 0, 0));
426+
write_texel(t_scale, out_pos, vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0));
427+
write_texel(t_zero_point, out_pos, ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0));
424428
}
425429

426430
// Synchronize before processing next token
@@ -517,8 +521,8 @@ void choose_qparams_block_wise() {
517521
calc_scale_zp(vmin, vmax, quant_min, quant_max, mapping_type, eps, scale, zp);
518522

519523
// Write the scalar values directly to buffer using linear index
520-
t_scale[blkIdx] = scale;
521-
t_zero_point[blkIdx] = zp;
524+
t_scale[blkIdx] = SCALE_OUT_T(scale);
525+
t_zero_point[blkIdx] = ZP_OUT_T(zp);
522526
}
523527
}
524528

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
choose_qparams_texture:
22
parameter_names_with_default_values:
33
IN_DTYPE: float
4+
SCALE_OUT_DTYPE: float
5+
ZP_OUT_DTYPE: int32
46
MODE: per_tensor
57
generate_variant_forall:
68
IN_DTYPE:
79
- VALUE: float
10+
SCALE_OUT_DTYPE:
11+
- VALUE: float
12+
ZP_OUT_DTYPE:
13+
- VALUE: int32
14+
- VALUE: int8
15+
- VALUE: float
816
shader_variants:
917
- NAME: choose_qparams_tensor_texture3d
1018
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212

1313
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
1414
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
15+
#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)}
16+
#define ZP_T ${buffer_scalar_type(ZP_DTYPE)}
1517

1618
#define ${MODE}
1719

1820
${define_active_storage_type("buffer")}
1921
${define_required_extensions(IN_DTYPE)}
2022
${define_required_extensions(OUT_DTYPE)}
23+
${define_required_extensions(SCALE_DTYPE)}
24+
${define_required_extensions(ZP_DTYPE)}
2125

2226
layout(std430) buffer;
2327

@@ -27,25 +31,25 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
2731
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
2832

2933
$if MODE == "per_tensor":
30-
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
31-
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
34+
${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")}
35+
${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")}
3236

3337
layout(push_constant) uniform restrict Block {
3438
int quant_min;
3539
int quant_max;
3640
};
3741
$if MODE == "per_token":
38-
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
39-
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
42+
${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")}
43+
${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")}
4044

4145
layout(push_constant) uniform restrict Block {
4246
int num_tokens;
4347
int quant_min;
4448
int quant_max;
4549
};
4650
$if MODE == "per_channel":
47-
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
48-
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
51+
${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")}
52+
${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")}
4953

5054
layout(push_constant) uniform restrict Block {
5155
int axis;
@@ -54,8 +58,8 @@ $if MODE == "per_channel":
5458
int quant_max;
5559
};
5660
$if MODE == "block_wise":
57-
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
58-
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
61+
${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")}
62+
${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")}
5963

6064
layout(push_constant) uniform restrict Block {
6165
ivec4 blockSize; // bW, bH, bC, bN
@@ -150,7 +154,7 @@ void dequantize_per_tensor() {
150154
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
151155

152156
IN_T qvalue = t_in[in_bufi];
153-
OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);
157+
OUT_T value = dequantize_val(qvalue, float(t_scale[0]), int(t_zero_point[0]));
154158

155159
t_out[out_bufi] = value;
156160
}
@@ -185,7 +189,7 @@ void dequantize_per_token() {
185189

186190
token_idx = min(token_idx, num_tokens - 1);
187191

188-
OUT_T value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]);
192+
OUT_T value = dequantize_val(qvalue, float(t_scale[token_idx]), int(t_zero_point[token_idx]));
189193

190194
t_out[out_bufi] = value;
191195
}
@@ -224,7 +228,7 @@ void dequantize_per_channel() {
224228

225229
channel_idx = min(channel_idx, num_channels - 1);
226230

227-
OUT_T value = dequantize_val(qvalue, t_scale[channel_idx], t_zero_point[channel_idx]);
231+
OUT_T value = dequantize_val(qvalue, float(t_scale[channel_idx]), int(t_zero_point[channel_idx]));
228232

229233
t_out[out_bufi] = value;
230234
}
@@ -247,7 +251,7 @@ void dequantize_block_wise() {
247251

248252
const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w;
249253

250-
const OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]);
254+
const OUT_T value = dequantize_val(qvalue, float(t_scale[block_id]), int(t_zero_point[block_id]));
251255

252256
t_out[out_bufi] = value;
253257
}

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ dequantize_buffer:
22
parameter_names_with_default_values:
33
IN_DTYPE: int32
44
OUT_DTYPE: float
5+
SCALE_DTYPE: float
6+
ZP_DTYPE: int32
57
MODE: per_tensor
68
generate_variant_forall:
79
IN_DTYPE:
@@ -12,6 +14,12 @@ dequantize_buffer:
1214
- VALUE: half
1315
- VALUE: float
1416
- VALUE: double
17+
SCALE_DTYPE:
18+
- VALUE: float
19+
ZP_DTYPE:
20+
- VALUE: int8
21+
- VALUE: int32
22+
- VALUE: float
1523
shader_variants:
1624
- NAME: dequantize_per_tensor_buffer
1725
MODE: per_tensor

0 commit comments

Comments
 (0)