Skip to content

Commit bbb913b

Browse files
authored
Add SCALE_DTYPE and ZP_DTYPE support for quantization shaders
Differential Revision: D79835267 Pull Request resolved: #13225
1 parent db813fa commit bbb913b

18 files changed

+595
-118
lines changed

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
499499
continue
500500

501501
# Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization)
502-
qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node)
502+
qta8a_qga4w_details = None
503503
if qta8a_qga4w_details is not None:
504504
group_size, weight_bits = qta8a_qga4w_details
505505
fuse_into_linear_qta8a_qga4w_node(

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,22 @@
1111
#define PRECISION ${PRECISION}
1212

1313
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
14+
#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)}
15+
#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)}
1416

1517
#define ${MODE}
1618

1719
${define_active_storage_type("buffer")}
1820
${define_required_extensions(IN_DTYPE)}
21+
${define_required_extensions(SCALE_OUT_DTYPE)}
22+
${define_required_extensions(ZP_OUT_DTYPE)}
1923

2024
#extension GL_EXT_control_flow_attributes : require
2125

2226
layout(std430) buffer;
2327

24-
${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")}
25-
${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")}
28+
${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")}
29+
${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")}
2630
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
2731

2832
$if MODE == "per_tensor":
@@ -254,8 +258,8 @@ void choose_qparams_per_tensor() {
254258
// Use default values: mapping_type=0 (ASYMMETRIC), eps from push constant
255259
calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val);
256260

257-
t_scale[0] = scale_val;
258-
t_zero_point[0] = zero_point_val;
261+
t_scale[0] = SCALE_OUT_T(scale_val);
262+
t_zero_point[0] = ZP_OUT_T(zero_point_val);
259263
}
260264
}
261265

@@ -306,8 +310,8 @@ void choose_qparams_per_token() {
306310
calc_scale_zp(lo, hi, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val);
307311

308312
// Write results
309-
t_scale[token_id] = scale_val;
310-
t_zero_point[token_id] = zero_point_val;
313+
t_scale[token_id] = SCALE_OUT_T(scale_val);
314+
t_zero_point[token_id] = ZP_OUT_T(zero_point_val);
311315
}
312316
}
313317

@@ -380,12 +384,12 @@ void choose_qparams_block_wise() {
380384
hi = 0.0;
381385
}
382386

383-
float scale;
384-
int zp;
385-
calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp);
387+
float scale_val;
388+
int zero_point_val;
389+
calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale_val, zero_point_val);
386390

387-
t_zero_point[block_id] = zp;
388-
t_scale[block_id] = scale;
391+
t_scale[block_id] = SCALE_OUT_T(scale_val);
392+
t_zero_point[block_id] = ZP_OUT_T(zero_point_val);
389393
}
390394
}
391395

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
choose_qparams_buffer:
22
parameter_names_with_default_values:
33
IN_DTYPE: float
4+
SCALE_OUT_DTYPE: float
5+
ZP_OUT_DTYPE: int32
46
MODE: per_tensor
57
generate_variant_forall:
68
IN_DTYPE:
79
- VALUE: float
10+
SCALE_OUT_DTYPE:
11+
- VALUE: float
12+
ZP_OUT_DTYPE:
13+
- VALUE: int32
14+
- VALUE: int8
15+
- VALUE: float
816
shader_variants:
917
- NAME: choose_qparams_tensor_buffer
1018
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,26 @@
1212

1313
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
1414
#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")}
15+
#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)}
16+
#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)}
1517

1618
#define ${MODE}
1719

1820
${define_active_storage_type("texture3d")}
1921
${define_required_extensions(IN_DTYPE)}
22+
${define_required_extensions(SCALE_OUT_DTYPE)}
23+
${define_required_extensions(ZP_OUT_DTYPE)}
2024

2125
#extension GL_EXT_control_flow_attributes : require
2226

2327
layout(std430) buffer;
2428

2529
$if MODE != "block_wise":
26-
${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")}
27-
${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")}
30+
${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "texture3d")}
31+
${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "texture3d")}
2832
$else:
29-
${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")}
30-
${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")}
33+
${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")}
34+
${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")}
3135

3236
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
3337

@@ -273,8 +277,8 @@ void choose_qparams_per_tensor() {
273277
int zero_point_val;
274278
calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val);
275279

276-
write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0));
277-
write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0));
280+
write_texel(t_scale, ivec3(0, 0, 0), vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0));
281+
write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0));
278282
}
279283
}
280284

@@ -419,8 +423,8 @@ void choose_qparams_per_token() {
419423
uint out_x = out_remainder % uint(t_scale_limits.x);
420424
ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z));
421425

422-
write_texel(t_scale, out_pos, vec4(scale_val, 0.0, 0.0, 0.0));
423-
write_texel(t_zero_point, out_pos, ivec4(zero_point_val, 0, 0, 0));
426+
write_texel(t_scale, out_pos, vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0));
427+
write_texel(t_zero_point, out_pos, ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0));
424428
}
425429

426430
// Synchronize before processing next token
@@ -517,8 +521,8 @@ void choose_qparams_block_wise() {
517521
calc_scale_zp(vmin, vmax, quant_min, quant_max, mapping_type, eps, scale, zp);
518522

519523
// Write the scalar values directly to buffer using linear index
520-
t_scale[blkIdx] = scale;
521-
t_zero_point[blkIdx] = zp;
524+
t_scale[blkIdx] = SCALE_OUT_T(scale);
525+
t_zero_point[blkIdx] = ZP_OUT_T(zp);
522526
}
523527
}
524528

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
choose_qparams_texture:
22
parameter_names_with_default_values:
33
IN_DTYPE: float
4+
SCALE_OUT_DTYPE: float
5+
ZP_OUT_DTYPE: int32
46
MODE: per_tensor
57
generate_variant_forall:
68
IN_DTYPE:
79
- VALUE: float
10+
SCALE_OUT_DTYPE:
11+
- VALUE: float
12+
ZP_OUT_DTYPE:
13+
- VALUE: int32
14+
- VALUE: int8
15+
- VALUE: float
816
shader_variants:
917
- NAME: choose_qparams_tensor_texture3d
1018
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212

1313
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
1414
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
15+
#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)}
16+
#define ZP_T ${buffer_scalar_type(ZP_DTYPE)}
1517

1618
#define ${MODE}
1719

1820
${define_active_storage_type("buffer")}
1921
${define_required_extensions(IN_DTYPE)}
2022
${define_required_extensions(OUT_DTYPE)}
23+
${define_required_extensions(SCALE_DTYPE)}
24+
${define_required_extensions(ZP_DTYPE)}
2125

2226
layout(std430) buffer;
2327

@@ -27,25 +31,25 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
2731
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
2832

2933
$if MODE == "per_tensor":
30-
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
31-
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
34+
${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")}
35+
${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")}
3236

3337
layout(push_constant) uniform restrict Block {
3438
int quant_min;
3539
int quant_max;
3640
};
3741
$if MODE == "per_token":
38-
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
39-
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
42+
${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")}
43+
${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")}
4044

4145
layout(push_constant) uniform restrict Block {
4246
int num_tokens;
4347
int quant_min;
4448
int quant_max;
4549
};
4650
$if MODE == "per_channel":
47-
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
48-
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
51+
${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")}
52+
${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")}
4953

5054
layout(push_constant) uniform restrict Block {
5155
int axis;
@@ -54,8 +58,8 @@ $if MODE == "per_channel":
5458
int quant_max;
5559
};
5660
$if MODE == "block_wise":
57-
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
58-
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
61+
${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")}
62+
${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")}
5963

6064
layout(push_constant) uniform restrict Block {
6165
ivec4 blockSize; // bW, bH, bC, bN
@@ -150,7 +154,7 @@ void dequantize_per_tensor() {
150154
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
151155

152156
IN_T qvalue = t_in[in_bufi];
153-
OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);
157+
OUT_T value = dequantize_val(qvalue, float(t_scale[0]), int(t_zero_point[0]));
154158

155159
t_out[out_bufi] = value;
156160
}
@@ -185,7 +189,7 @@ void dequantize_per_token() {
185189

186190
token_idx = min(token_idx, num_tokens - 1);
187191

188-
OUT_T value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]);
192+
OUT_T value = dequantize_val(qvalue, float(t_scale[token_idx]), int(t_zero_point[token_idx]));
189193

190194
t_out[out_bufi] = value;
191195
}
@@ -224,7 +228,7 @@ void dequantize_per_channel() {
224228

225229
channel_idx = min(channel_idx, num_channels - 1);
226230

227-
OUT_T value = dequantize_val(qvalue, t_scale[channel_idx], t_zero_point[channel_idx]);
231+
OUT_T value = dequantize_val(qvalue, float(t_scale[channel_idx]), int(t_zero_point[channel_idx]));
228232

229233
t_out[out_bufi] = value;
230234
}
@@ -247,7 +251,7 @@ void dequantize_block_wise() {
247251

248252
const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w;
249253

250-
const OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]);
254+
const OUT_T value = dequantize_val(qvalue, float(t_scale[block_id]), int(t_zero_point[block_id]));
251255

252256
t_out[out_bufi] = value;
253257
}

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ dequantize_buffer:
22
parameter_names_with_default_values:
33
IN_DTYPE: int32
44
OUT_DTYPE: float
5+
SCALE_DTYPE: float
6+
ZP_DTYPE: int32
57
MODE: per_tensor
68
generate_variant_forall:
79
IN_DTYPE:
@@ -12,6 +14,12 @@ dequantize_buffer:
1214
- VALUE: half
1315
- VALUE: float
1416
- VALUE: double
17+
SCALE_DTYPE:
18+
- VALUE: float
19+
ZP_DTYPE:
20+
- VALUE: int8
21+
- VALUE: int32
22+
- VALUE: float
1523
shader_variants:
1624
- NAME: dequantize_per_tensor_buffer
1725
MODE: per_tensor

0 commit comments

Comments
 (0)