diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py index aa4829d9c90..3d3214bb4ee 100644 --- a/backends/vulkan/_passes/fuse_quantized_ops.py +++ b/backends/vulkan/_passes/fuse_quantized_ops.py @@ -499,7 +499,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: continue # Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization) - qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node) + qta8a_qga4w_details = None if qta8a_qga4w_details is not None: group_size, weight_bits = qta8a_qga4w_details fuse_into_linear_qta8a_qga4w_node( diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl index 99a64c3589e..7e21bcf0eba 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl @@ -11,18 +11,22 @@ #define PRECISION ${PRECISION} #define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)} +#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)} #define ${MODE} ${define_active_storage_type("buffer")} ${define_required_extensions(IN_DTYPE)} +${define_required_extensions(SCALE_OUT_DTYPE)} +${define_required_extensions(ZP_OUT_DTYPE)} #extension GL_EXT_control_flow_attributes : require layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")} -${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")} +${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")} +${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")} ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} $if MODE == "per_tensor": @@ -254,8 +258,8 @@ void choose_qparams_per_tensor() { // Use default values: mapping_type=0 (ASYMMETRIC), eps from push constant calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); - t_scale[0] = scale_val; - t_zero_point[0] = zero_point_val; + t_scale[0] = SCALE_OUT_T(scale_val); + t_zero_point[0] = ZP_OUT_T(zero_point_val); } } @@ -306,8 +310,8 @@ void choose_qparams_per_token() { calc_scale_zp(lo, hi, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); // Write results - t_scale[token_id] = scale_val; - t_zero_point[token_id] = zero_point_val; + t_scale[token_id] = SCALE_OUT_T(scale_val); + t_zero_point[token_id] = ZP_OUT_T(zero_point_val); } } @@ -380,12 +384,12 @@ void choose_qparams_block_wise() { hi = 0.0; } - float scale; - int zp; - calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp); + float scale_val; + int zero_point_val; + calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale_val, zero_point_val); - t_zero_point[block_id] = zp; - t_scale[block_id] = scale; + t_scale[block_id] = SCALE_OUT_T(scale_val); + t_zero_point[block_id] = ZP_OUT_T(zero_point_val); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml index ee900750e16..8459b043baa 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml @@ -1,10 +1,18 @@ choose_qparams_buffer: parameter_names_with_default_values: IN_DTYPE: float + SCALE_OUT_DTYPE: float + ZP_OUT_DTYPE: int32 MODE: per_tensor generate_variant_forall: IN_DTYPE: - VALUE: float + 
SCALE_OUT_DTYPE: + - VALUE: float + ZP_OUT_DTYPE: + - VALUE: int32 + - VALUE: int8 + - VALUE: float shader_variants: - NAME: choose_qparams_tensor_buffer MODE: per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl index 62ea7099f8c..a17a3ae41dd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -12,22 +12,26 @@ #define IN_T ${buffer_scalar_type(IN_DTYPE)} #define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} +#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)} +#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)} #define ${MODE} ${define_active_storage_type("texture3d")} ${define_required_extensions(IN_DTYPE)} +${define_required_extensions(SCALE_OUT_DTYPE)} +${define_required_extensions(ZP_OUT_DTYPE)} #extension GL_EXT_control_flow_attributes : require layout(std430) buffer; $if MODE != "block_wise": - ${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")} - ${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")} + ${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "texture3d")} + ${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "texture3d")} $else: - ${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")} + ${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")} ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} @@ -273,8 +277,8 @@ void choose_qparams_per_tensor() { int zero_point_val; calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); - write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); - write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); + write_texel(t_scale, ivec3(0, 0, 0), vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0)); + write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0)); } } @@ -419,8 +423,8 @@ void choose_qparams_per_token() { uint out_x = out_remainder % uint(t_scale_limits.x); ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z)); - write_texel(t_scale, out_pos, vec4(scale_val, 0.0, 0.0, 0.0)); - write_texel(t_zero_point, out_pos, ivec4(zero_point_val, 0, 0, 0)); + write_texel(t_scale, out_pos, vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0)); + write_texel(t_zero_point, out_pos, ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0)); } // Synchronize before processing next token @@ -517,8 +521,8 @@ void choose_qparams_block_wise() { calc_scale_zp(vmin, vmax, quant_min, quant_max, mapping_type, eps, scale, zp); // Write the scalar values directly to buffer using linear index - t_scale[blkIdx] = scale; - t_zero_point[blkIdx] = zp; + t_scale[blkIdx] = SCALE_OUT_T(scale); + t_zero_point[blkIdx] = ZP_OUT_T(zp); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml index a097ce0da48..12228822d4b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml @@ -1,10 +1,18 @@ choose_qparams_texture: parameter_names_with_default_values: IN_DTYPE: float + SCALE_OUT_DTYPE: float + ZP_OUT_DTYPE: int32 MODE: per_tensor generate_variant_forall: IN_DTYPE: - VALUE: 
float + SCALE_OUT_DTYPE: + - VALUE: float + ZP_OUT_DTYPE: + - VALUE: int32 + - VALUE: int8 + - VALUE: float shader_variants: - NAME: choose_qparams_tensor_texture3d MODE: per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl index 43e62eadeee..57dc2d53fff 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl @@ -12,12 +12,16 @@ #define IN_T ${buffer_scalar_type(IN_DTYPE)} #define OUT_T ${buffer_scalar_type(OUT_DTYPE)} +#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} +#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} #define ${MODE} ${define_active_storage_type("buffer")} ${define_required_extensions(IN_DTYPE)} ${define_required_extensions(OUT_DTYPE)} +${define_required_extensions(SCALE_DTYPE)} +${define_required_extensions(ZP_DTYPE)} layout(std430) buffer; @@ -27,16 +31,16 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} $if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int quant_min; int quant_max; }; $if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int num_tokens; @@ -44,8 +48,8 @@ $if MODE == "per_token": int quant_max; }; $if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int axis; @@ -54,8 +58,8 @@ $if MODE == "per_channel": int quant_max; }; $if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { ivec4 blockSize; // bW, bH, bC, bN @@ -150,7 +154,7 @@ void dequantize_per_tensor() { const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); IN_T qvalue = t_in[in_bufi]; - OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]); + OUT_T value = dequantize_val(qvalue, float(t_scale[0]), int(t_zero_point[0])); t_out[out_bufi] = value; } @@ -185,7 +189,7 @@ void dequantize_per_token() { token_idx = min(token_idx, num_tokens - 1); - OUT_T value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]); + OUT_T value = dequantize_val(qvalue, float(t_scale[token_idx]), int(t_zero_point[token_idx])); t_out[out_bufi] = value; } @@ -224,7 +228,7 @@ void dequantize_per_channel() { channel_idx = min(channel_idx, num_channels - 1); - OUT_T value = dequantize_val(qvalue, t_scale[channel_idx], t_zero_point[channel_idx]); + OUT_T value = dequantize_val(qvalue, 
float(t_scale[channel_idx]), int(t_zero_point[channel_idx])); t_out[out_bufi] = value; } @@ -247,7 +251,7 @@ void dequantize_block_wise() { const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - const OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]); + const OUT_T value = dequantize_val(qvalue, float(t_scale[block_id]), int(t_zero_point[block_id])); t_out[out_bufi] = value; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml index 999c59d3b79..a4375038a75 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -2,6 +2,8 @@ dequantize_buffer: parameter_names_with_default_values: IN_DTYPE: int32 OUT_DTYPE: float + SCALE_DTYPE: float + ZP_DTYPE: int32 MODE: per_tensor generate_variant_forall: IN_DTYPE: @@ -12,6 +14,12 @@ dequantize_buffer: - VALUE: half - VALUE: float - VALUE: double + SCALE_DTYPE: + - VALUE: float + ZP_DTYPE: + - VALUE: int8 + - VALUE: int32 + - VALUE: float shader_variants: - NAME: dequantize_per_tensor_buffer MODE: per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl index 20bf6c87e26..19276cd8f7f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -15,12 +15,16 @@ #define OUT_T ${buffer_scalar_type(OUT_DTYPE)} #define FVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} +#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} +#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} #define ${MODE} ${define_active_storage_type("texture3d")} ${define_required_extensions(IN_DTYPE)} ${define_required_extensions(OUT_DTYPE)} +${define_required_extensions(SCALE_DTYPE)} +${define_required_extensions(ZP_DTYPE)} #extension GL_EXT_control_flow_attributes : require @@ -30,16 +34,16 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} $if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int quant_min; int quant_max; }; $if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int num_tokens; @@ -47,8 +51,8 @@ $if MODE == "per_token": int quant_max; }; $if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int axis; @@ -57,8 +61,8 @@ $if MODE == "per_channel": int quant_max; }; $if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, 
"r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { ivec4 blockSize; // bW, bH, bC, bN @@ -160,7 +164,7 @@ void dequantize_per_tensor() { [[unroll]] for (int i = 0; i < 4; ++i) { IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]); + OUT_T value = dequantize_val(qvalue, float(t_scale[0]), int(t_zero_point[0])); $if OUT_DTYPE == "double": outtex[i] = float(value); @@ -196,8 +200,8 @@ void dequantize_per_token() { token_idx = min(token_idx, num_tokens - 1); // Scale and zero_point are prepacked as buffers, so direct access - float scale_val = t_scale[token_idx]; - int zero_point_val = t_zero_point[token_idx]; + float scale_val = float(t_scale[token_idx]); + int zero_point_val = int(t_zero_point[token_idx]); FVEC4_T outtex; [[unroll]] for (int i = 0; i < 4; ++i) { @@ -238,8 +242,8 @@ void dequantize_per_channel() { int channel_idx = pos.x * 4 + i; channel_idx = min(channel_idx, num_channels - 1); - float scale_val = t_scale[channel_idx]; - int zero_point_val = t_zero_point[channel_idx]; + float scale_val = float(t_scale[channel_idx]); + int zero_point_val = int(t_zero_point[channel_idx]); OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); $if OUT_DTYPE == "double": outtex[i] = float(value); @@ -249,8 +253,8 @@ void dequantize_per_channel() { } else if (axis == 1) { int channel_idx = pos.y; channel_idx = min(channel_idx, num_channels - 1); - float scale_val = t_scale[channel_idx]; - int zero_point_val = t_zero_point[channel_idx]; + float scale_val = float(t_scale[channel_idx]); + int zero_point_val = int(t_zero_point[channel_idx]); [[unroll]] for (int i = 0; i < 4; ++i) { IN_T qvalue = IN_T(intex[i]); @@ -267,8 +271,8 @@ void dequantize_per_channel() { int folded_idx = pos.z; int channel_idx = folded_idx % num_channels; - float scale_val = t_scale[channel_idx]; - int zero_point_val = t_zero_point[channel_idx]; + float scale_val = float(t_scale[channel_idx]); + int zero_point_val = int(t_zero_point[channel_idx]); [[unroll]] for (int i = 0; i < 4; ++i) { IN_T qvalue = IN_T(intex[i]); @@ -287,8 +291,8 @@ void dequantize_per_channel() { // the C dimension N(C)HW int channel_idx = folded_idx / num_channels; - float scale_val = t_scale[channel_idx]; - int zero_point_val = t_zero_point[channel_idx]; + float scale_val = float(t_scale[channel_idx]); + int zero_point_val = int(t_zero_point[channel_idx]); [[unroll]] for (int i = 0; i < 4; ++i) { IN_T qvalue = IN_T(intex[i]); @@ -326,7 +330,7 @@ void dequantize_block_wise() { int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]); + OUT_T value = dequantize_val(qvalue, float(t_scale[block_id]), int(t_zero_point[block_id])); $if OUT_DTYPE == "double": outtex[i] = float(value); $else: diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml index 9b624762192..7a58e9410d3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -2,6 +2,8 @@ dequantize_texture: parameter_names_with_default_values: IN_DTYPE: int32 OUT_DTYPE: float + SCALE_DTYPE: float + ZP_DTYPE: int32 MODE: 
per_tensor generate_variant_forall: IN_DTYPE: @@ -12,6 +14,12 @@ dequantize_texture: - VALUE: half - VALUE: float - VALUE: double + SCALE_DTYPE: + - VALUE: float + ZP_DTYPE: + - VALUE: int8 + - VALUE: int32 + - VALUE: float shader_variants: - NAME: dequantize_per_tensor_texture3d MODE: per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl index 9a342d8e057..7bf3a932c6c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl @@ -12,12 +12,16 @@ #define IN_T ${buffer_scalar_type(IN_DTYPE)} #define OUT_T ${buffer_scalar_type(OUT_DTYPE)} +#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} +#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} #define ${MODE} ${define_active_storage_type("buffer")} ${define_required_extensions(IN_DTYPE)} ${define_required_extensions(OUT_DTYPE)} +${define_required_extensions(SCALE_DTYPE)} +${define_required_extensions(ZP_DTYPE)} layout(std430) buffer; @@ -27,16 +31,16 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} $if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int quant_min; int quant_max; }; $if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int num_tokens; @@ -44,8 +48,8 @@ $if MODE == "per_token": int quant_max; }; $if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int axis; @@ -54,8 +58,8 @@ $if MODE == "per_channel": int quant_max; }; $if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { ivec4 blockSize; // bW, bH, bC, bN @@ -144,7 +148,7 @@ void quantize_per_tensor() { const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); IN_T value = t_in[in_bufi]; - OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]); + OUT_T qvalue = quantize_val(value, float(t_scale[0]), int(t_zero_point[0])); t_out[out_bufi] = qvalue; } @@ -179,7 +183,7 @@ void quantize_per_token() { token_idx = min(token_idx, num_tokens - 1); - OUT_T qvalue = quantize_val(value, t_scale[token_idx], t_zero_point[token_idx]); + OUT_T qvalue = quantize_val(value, float(t_scale[token_idx]), int(t_zero_point[token_idx])); t_out[out_bufi] = qvalue; } @@ -218,7 +222,7 @@ void quantize_per_channel() { channel_idx = min(channel_idx, num_channels - 1); - OUT_T qvalue = quantize_val(value, 
t_scale[channel_idx], t_zero_point[channel_idx]); + OUT_T qvalue = quantize_val(value, float(t_scale[channel_idx]), int(t_zero_point[channel_idx])); t_out[out_bufi] = qvalue; } @@ -241,7 +245,7 @@ void quantize_block_wise() { const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - const OUT_T qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id]); + const OUT_T qvalue = quantize_val(value, float(t_scale[block_id]), int(t_zero_point[block_id])); t_out[out_bufi] = qvalue; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml index 5b479c2f90f..fb5853ecd20 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml @@ -2,6 +2,8 @@ quantize_buffer: parameter_names_with_default_values: IN_DTYPE: float OUT_DTYPE: int32 + SCALE_DTYPE: float + ZP_DTYPE: int32 MODE: per_tensor generate_variant_forall: IN_DTYPE: @@ -12,6 +14,12 @@ quantize_buffer: - VALUE: uint8 - VALUE: int8 - VALUE: int32 + SCALE_DTYPE: + - VALUE: float + ZP_DTYPE: + - VALUE: int8 + - VALUE: int32 + - VALUE: float shader_variants: - NAME: quantize_per_tensor_buffer MODE: per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl index 69f219ef329..12e5769f50d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl @@ -15,12 +15,16 @@ #define OUT_T ${buffer_scalar_type(OUT_DTYPE)} #define IVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} +#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} +#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} #define ${MODE} ${define_active_storage_type("texture3d")} ${define_required_extensions(IN_DTYPE)} ${define_required_extensions(OUT_DTYPE)} +${define_required_extensions(SCALE_DTYPE)} +${define_required_extensions(ZP_DTYPE)} #extension GL_EXT_control_flow_attributes : require @@ -32,16 +36,16 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} $if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int quant_min; int quant_max; }; $if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int num_tokens; @@ -49,8 +53,8 @@ $if MODE == "per_token": int quant_max; }; $if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict Block { int axis; @@ -59,8 +63,8 @@ $if MODE == "per_channel": int quant_max; }; $if MODE == "block_wise": - ${layout_declare_tensor(B, "r", 
"t_scale", "float", "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} layout(push_constant) uniform restrict BlockPC { ivec4 blockSize; // WHCN @@ -148,7 +152,7 @@ void quantize_per_tensor() { [[unroll]] for (int i = 0; i < 4; ++i) { IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]); + OUT_T qvalue = quantize_val(value, float(t_scale[0]), int(t_zero_point[0])); outtex[i] = qvalue; } write_texel(t_out, pos, outtex); @@ -180,8 +184,8 @@ void quantize_per_token() { token_idx = min(token_idx, num_tokens - 1); // Scale and zero_point are prepacked as buffers, so direct access - float scale_val = t_scale[token_idx]; - int zero_point_val = t_zero_point[token_idx]; + float scale_val = float(t_scale[token_idx]); + int zero_point_val = int(t_zero_point[token_idx]); IVEC4_T outtex; [[unroll]] for (int i = 0; i < 4; ++i) { @@ -219,8 +223,8 @@ void quantize_per_channel() { int channel_idx = pos.x * 4 + i; channel_idx = min(channel_idx, num_channels - 1); - float scale_val = t_scale[channel_idx]; - int zero_point_val = t_zero_point[channel_idx]; + float scale_val = float(t_scale[channel_idx]); + int zero_point_val = int(t_zero_point[channel_idx]); OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); outtex[i] = qvalue; } @@ -228,8 +232,8 @@ void quantize_per_channel() { // Height dimension - all texel components use same channel index int channel_idx = pos.y; channel_idx = min(channel_idx, num_channels - 1); - float scale_val = t_scale[channel_idx]; - int zero_point_val = t_zero_point[channel_idx]; + float scale_val = float(t_scale[channel_idx]); + int zero_point_val = int(t_zero_point[channel_idx]); [[unroll]] for (int i = 0; i < 4; ++i) { IN_T value = IN_T(intex[i]); @@ -243,8 +247,8 @@ void quantize_per_channel() { int folded_idx = pos.z; int channel_idx = folded_idx % num_channels; - float scale_val = t_scale[channel_idx]; - int zero_point_val = t_zero_point[channel_idx]; + float scale_val = float(t_scale[channel_idx]); + int zero_point_val = int(t_zero_point[channel_idx]); [[unroll]] for (int i = 0; i < 4; ++i) { IN_T value = IN_T(intex[i]); @@ -258,8 +262,8 @@ void quantize_per_channel() { int folded_idx = pos.z; int batch_idx = folded_idx / num_channels; - float scale_val = t_scale[batch_idx]; - int zero_point_val = t_zero_point[batch_idx]; + float scale_val = float(t_scale[batch_idx]); + int zero_point_val = int(t_zero_point[batch_idx]); [[unroll]] for (int i = 0; i < 4; ++i) { IN_T value = IN_T(intex[i]); @@ -294,7 +298,7 @@ void quantize_block_wise() { int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id]); + OUT_T qvalue = quantize_val(value, float(t_scale[block_id]), int(t_zero_point[block_id])); outtex[i] = qvalue; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml index 2e40ac90794..03d418ff2f7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -2,6 +2,8 @@ quantize_texture: parameter_names_with_default_values: IN_DTYPE: float OUT_DTYPE: int32 + SCALE_DTYPE: float + ZP_DTYPE: int32 MODE: per_tensor 
generate_variant_forall: IN_DTYPE: @@ -12,6 +14,12 @@ quantize_texture: - VALUE: uint8 - VALUE: int8 - VALUE: int32 + SCALE_DTYPE: + - VALUE: float + ZP_DTYPE: + - VALUE: int8 + - VALUE: int32 + - VALUE: float shader_variants: - NAME: quantize_per_tensor_texture3d MODE: per_tensor diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index 76d352334e3..2cf837fa89c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -169,9 +169,35 @@ void add_choose_qparams_tensor_node( std::string kernel_name("choose_qparams_tensor"); add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); + add_dtype_suffix(kernel_name, graph.dtype_of(zero_point_out)); - int quant_min_val = static_cast(graph.get_int(quant_min)); - int quant_max_val = static_cast(graph.get_int(quant_max)); + // Handle optional quant_min and quant_max parameters independently + auto bounds = get_dtype_bounds(graph.dtype_of(zero_point_out)); + + int quant_min_val, quant_max_val; + + // Handle quant_min + if (graph.val_is_none(quant_min)) { + quant_min_val = bounds.first; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_min), + "quant_min must be an integer, got type: ", + graph.get_val_type(quant_min)); + quant_min_val = static_cast(graph.get_int(quant_min)); + } + + // Handle quant_max + if (graph.val_is_none(quant_max)) { + quant_max_val = bounds.second; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_max), + "quant_max must be an integer, got type: ", + graph.get_val_type(quant_max)); + quant_max_val = static_cast(graph.get_int(quant_max)); + } float eps_val = static_cast(graph.get_double(eps)); vkapi::ParamsBindList param_ubos; @@ -227,6 +253,8 @@ void add_choose_qparams_per_token_asymmetric_node( std::string kernel_name("choose_qparams_per_token_asymmetric"); add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); + add_dtype_suffix(kernel_name, graph.dtype_of(zero_point_out)); // Calculate number of tokens (product of all dimensions except the last one) int64_t num_tokens = 1; @@ -317,9 +345,26 @@ void add_choose_qparams_block_wise_node( num_blocks_vec[0] * num_blocks_vec[1], num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; - int qmin = static_cast(graph.get_int(quant_min)); - int qmax = static_cast(graph.get_int(quant_max)); - float eps_val = static_cast(graph.get_double(eps)); + // Handle optional quant_min and quant_max parameters + int qmin, qmax; + if (graph.val_is_none(quant_min) || graph.val_is_none(quant_max)) { + // Use default values based on target_dtype (similar to + // _get_and_check_qmin_qmax) For now, assume int8 range as default - this + // should match the Python implementation + qmin = -128; + qmax = 127; + } else { + qmin = static_cast(graph.get_int(quant_min)); + qmax = static_cast(graph.get_int(quant_max)); + } + + float eps_val; + if (graph.val_is_none(eps)) { + // Use default eps value (similar to Python implementation) + eps_val = 1.192092896e-07f; // torch.finfo(torch.float32).eps + } else { + eps_val = static_cast(graph.get_double(eps)); + } // Create push constants vector std::vector push_constants = { @@ -334,6 +379,8 @@ void add_choose_qparams_block_wise_node( std::string 
kernel_name("choose_qparams_block_wise"); add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); + add_dtype_suffix(kernel_name, graph.dtype_of(zp_out)); vkapi::ParamsBindList param_ubos; @@ -408,9 +455,18 @@ void choose_qparams_tensor_impl( // Verify input is a floating point type VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); + // Get scale and zero point output dtypes + vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); + vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); + + // Verify supported output types for scale (fp32 only for now) + VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); + + // Verify supported output types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_out_dtype == vkapi::kInt || + zero_point_out_dtype == vkapi::kChar || + zero_point_out_dtype == vkapi::kFloat); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -449,9 +505,18 @@ void choose_qparams_per_token_asymmetric_impl( // Verify input is a floating point type VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); + // Get scale and zero point output dtypes + vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); + vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); + + // Verify supported output types for scale (fp32 only for now) + VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); + + // Verify supported output types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_out_dtype == vkapi::kInt || + zero_point_out_dtype == vkapi::kChar || + zero_point_out_dtype == vkapi::kFloat); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -499,9 +564,18 @@ void choose_qparams_affine_impl( // Verify input is a floating point type VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); + // Get scale and zero point dtypes from arguments + vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); + vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); + + // Verify supported output types for scale (fp32 only for now) + VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); + + // Verify supported output types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_out_dtype == vkapi::kInt || + zero_point_out_dtype == vkapi::kChar || + zero_point_out_dtype == vkapi::kFloat); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -515,12 +589,14 @@ void choose_qparams_affine_impl( std::string mapping_type_str = graph.get_string(mapping_type); int mapping_type_val = 0; // Default to ASYMMETRIC - if (mapping_type_str == "ASYMMETRIC") { - mapping_type_val = 0; + if (mapping_type_str == "ASYMMETRIC" || mapping_type_str.empty()) { + mapping_type_val = 0; // ASYMMETRIC } else if (mapping_type_str == "SYMMETRIC") { mapping_type_val = 1; } else if (mapping_type_str == "SYMMETRIC_NO_CLIPPING_ERR") { mapping_type_val = 2; + } else { + 
VK_THROW("Unsupported mapping_type: ", mapping_type_str); } add_choose_qparams_block_wise_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 0822dcb05f3..a217734653d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -107,9 +107,35 @@ void add_dequantize_per_tensor_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(output)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale)); + add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - int quant_min_val = static_cast(graph.get_int(quant_min)); - int quant_max_val = static_cast(graph.get_int(quant_max)); + // Handle optional quant_min and quant_max parameters independently + auto bounds = get_dtype_bounds(graph.dtype_of(input)); + + int quant_min_val, quant_max_val; + + // Handle quant_min + if (graph.val_is_none(quant_min)) { + quant_min_val = bounds.first; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_min), + "quant_min must be an integer, got type: ", + graph.get_val_type(quant_min)); + quant_min_val = static_cast(graph.get_int(quant_min)); + } + + // Handle quant_max + if (graph.val_is_none(quant_max)) { + quant_max_val = bounds.second; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_max), + "quant_max must be an integer, got type: ", + graph.get_val_type(quant_max)); + quant_max_val = static_cast(graph.get_int(quant_max)); + } vkapi::ParamsBindList param_ubos; std::vector push_constants; @@ -169,9 +195,35 @@ void add_dequantize_per_token_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(output)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale)); + add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); + + // Handle optional quant_min and quant_max parameters independently + auto bounds = get_dtype_bounds(graph.dtype_of(input)); + + int quant_min_val, quant_max_val; - int quant_min_val = static_cast(graph.get_int(quant_min)); - int quant_max_val = static_cast(graph.get_int(quant_max)); + // Handle quant_min + if (graph.val_is_none(quant_min)) { + quant_min_val = bounds.first; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_min), + "quant_min must be an integer, got type: ", + graph.get_val_type(quant_min)); + quant_min_val = static_cast(graph.get_int(quant_min)); + } + + // Handle quant_max + if (graph.val_is_none(quant_max)) { + quant_max_val = bounds.second; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_max), + "quant_max must be an integer, got type: ", + graph.get_val_type(quant_max)); + quant_max_val = static_cast(graph.get_int(quant_max)); + } int num_tokens = static_cast(graph.sizes_of(scale)[0]); @@ -235,10 +287,37 @@ void add_dequantize_per_channel_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(output)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale)); + add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); int axis_val = static_cast(graph.get_int(axis)); - int quant_min_val = static_cast(graph.get_int(quant_min)); - int quant_max_val = static_cast(graph.get_int(quant_max)); + + // Handle optional quant_min and quant_max parameters independently + auto 
bounds = get_dtype_bounds(graph.dtype_of(input)); + + int quant_min_val, quant_max_val; + + // Handle quant_min + if (graph.val_is_none(quant_min)) { + quant_min_val = bounds.first; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_min), + "quant_min must be an integer, got type: ", + graph.get_val_type(quant_min)); + quant_min_val = static_cast(graph.get_int(quant_min)); + } + + // Handle quant_max + if (graph.val_is_none(quant_max)) { + quant_max_val = bounds.second; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_max), + "quant_max must be an integer, got type: ", + graph.get_val_type(quant_max)); + quant_max_val = static_cast(graph.get_int(quant_max)); + } // Normalize axis and convert from NCHW to WHCN using utility functions const auto input_sizes = graph.sizes_of(input); @@ -320,9 +399,35 @@ void add_dequantize_block_wise_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(output)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale)); + add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); + + // Handle optional quant_min and quant_max parameters independently + auto bounds = get_dtype_bounds(graph.dtype_of(input)); - int quant_min_val = static_cast(graph.get_int(quant_min)); - int quant_max_val = static_cast(graph.get_int(quant_max)); + int quant_min_val, quant_max_val; + + // Handle quant_min + if (graph.val_is_none(quant_min)) { + quant_min_val = bounds.first; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_min), + "quant_min must be an integer, got type: ", + graph.get_val_type(quant_min)); + quant_min_val = static_cast(graph.get_int(quant_min)); + } + + // Handle quant_max + if (graph.val_is_none(quant_max)) { + quant_max_val = bounds.second; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_max), + "quant_max must be an integer, got type: ", + graph.get_val_type(quant_max)); + quant_max_val = static_cast(graph.get_int(quant_max)); + } const auto input_sizes = graph.sizes_of(input); const auto block_size_list = graph.get_int_list(block_size); @@ -423,6 +528,18 @@ void dequantize_per_tensor_impl( graph.dtype_of(input) == vkapi::kChar || graph.dtype_of(input) == vkapi::kInt); + // Get scale and zero point dtypes + vkapi::ScalarType scale_dtype = graph.dtype_of(scale); + vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); + + // Verify supported types for scale (fp32 only for now) + VK_CHECK_COND(scale_dtype == vkapi::kFloat); + + // Verify supported types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || + zero_point_dtype == vkapi::kFloat); + // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -470,6 +587,18 @@ void dequantize_per_token_impl( graph.dtype_of(input) == vkapi::kChar || graph.dtype_of(input) == vkapi::kInt); + // Get scale and zero point dtypes + vkapi::ScalarType scale_dtype = graph.dtype_of(scale); + vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); + + // Verify supported types for scale (fp32 only for now) + VK_CHECK_COND(scale_dtype == vkapi::kFloat); + + // Verify supported types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || + zero_point_dtype == vkapi::kFloat); + // Check that scale and zero_point have 
buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -545,6 +674,18 @@ void dequantize_per_channel_impl( graph.dtype_of(input) == vkapi::kChar || graph.dtype_of(input) == vkapi::kInt); + // Get scale and zero point dtypes + vkapi::ScalarType scale_dtype = graph.dtype_of(scale); + vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); + + // Verify supported types for scale (fp32 only for now) + VK_CHECK_COND(scale_dtype == vkapi::kFloat); + + // Verify supported types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || + zero_point_dtype == vkapi::kFloat); + // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -625,6 +766,18 @@ void dequantize_affine_impl( graph.dtype_of(input) == vkapi::kChar || graph.dtype_of(input) == vkapi::kInt); + // Get scale and zero point dtypes + vkapi::ScalarType scale_dtype = graph.dtype_of(scale); + vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); + + // Verify supported types for scale (fp32 only for now) + VK_CHECK_COND(scale_dtype == vkapi::kFloat); + + // Verify supported types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || + zero_point_dtype == vkapi::kFloat); + // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index d4d0ba30293..88f77261f4f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -108,9 +108,35 @@ void add_quantize_per_tensor_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(output)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale)); + add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - int quant_min_val = static_cast(graph.get_int(quant_min)); - int quant_max_val = static_cast(graph.get_int(quant_max)); + // Handle optional quant_min and quant_max parameters independently + auto bounds = get_dtype_bounds(graph.dtype_of(output)); + + int quant_min_val, quant_max_val; + + // Handle quant_min + if (graph.val_is_none(quant_min)) { + quant_min_val = bounds.first; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_min), + "quant_min must be an integer, got type: ", + graph.get_val_type(quant_min)); + quant_min_val = static_cast(graph.get_int(quant_min)); + } + + // Handle quant_max + if (graph.val_is_none(quant_max)) { + quant_max_val = bounds.second; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_max), + "quant_max must be an integer, got type: ", + graph.get_val_type(quant_max)); + quant_max_val = static_cast(graph.get_int(quant_max)); + } vkapi::ParamsBindList param_ubos; std::vector push_constants; @@ -170,9 +196,35 @@ void add_quantize_per_token_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(output)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale)); + 
add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); + + // Handle optional quant_min and quant_max parameters independently + auto bounds = get_dtype_bounds(graph.dtype_of(output)); + + int quant_min_val, quant_max_val; + + // Handle quant_min + if (graph.val_is_none(quant_min)) { + quant_min_val = bounds.first; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_min), + "quant_min must be an integer, got type: ", + graph.get_val_type(quant_min)); + quant_min_val = static_cast(graph.get_int(quant_min)); + } - int quant_min_val = static_cast(graph.get_int(quant_min)); - int quant_max_val = static_cast(graph.get_int(quant_max)); + // Handle quant_max + if (graph.val_is_none(quant_max)) { + quant_max_val = bounds.second; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_max), + "quant_max must be an integer, got type: ", + graph.get_val_type(quant_max)); + quant_max_val = static_cast(graph.get_int(quant_max)); + } int num_tokens = static_cast(graph.sizes_of(scale)[0]); @@ -243,10 +295,37 @@ void add_quantize_per_channel_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(output)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale)); + add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); int axis_val = static_cast(graph.get_int(axis)); - int quant_min_val = static_cast(graph.get_int(quant_min)); - int quant_max_val = static_cast(graph.get_int(quant_max)); + + // Handle optional quant_min and quant_max parameters independently + auto bounds = get_dtype_bounds(graph.dtype_of(output)); + + int quant_min_val, quant_max_val; + + // Handle quant_min + if (graph.val_is_none(quant_min)) { + quant_min_val = bounds.first; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_min), + "quant_min must be an integer, got type: ", + graph.get_val_type(quant_min)); + quant_min_val = static_cast(graph.get_int(quant_min)); + } + + // Handle quant_max + if (graph.val_is_none(quant_max)) { + quant_max_val = bounds.second; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_max), + "quant_max must be an integer, got type: ", + graph.get_val_type(quant_max)); + quant_max_val = static_cast(graph.get_int(quant_max)); + } // Normalize axis and convert from NCHW to WHCN using utility functions const auto input_sizes = graph.sizes_of(input); @@ -336,9 +415,35 @@ void add_quantize_block_wise_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(output)); + add_dtype_suffix(kernel_name, graph.dtype_of(scale)); + add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); + + // Handle optional quant_min and quant_max parameters independently + auto bounds = get_dtype_bounds(graph.dtype_of(output)); - int quant_min_val = static_cast(graph.get_int(quant_min)); - int quant_max_val = static_cast(graph.get_int(quant_max)); + int quant_min_val, quant_max_val; + + // Handle quant_min + if (graph.val_is_none(quant_min)) { + quant_min_val = bounds.first; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_min), + "quant_min must be an integer, got type: ", + graph.get_val_type(quant_min)); + quant_min_val = static_cast(graph.get_int(quant_min)); + } + + // Handle quant_max + if (graph.val_is_none(quant_max)) { + quant_max_val = bounds.second; + } else { + VK_CHECK_COND( + graph.val_is_int(quant_max), + "quant_max must be an integer, got type: ", + 
graph.get_val_type(quant_max)); + quant_max_val = static_cast(graph.get_int(quant_max)); + } const auto input_sizes = graph.sizes_of(input); const auto block_size_list = graph.get_int_list(block_size); @@ -427,6 +532,8 @@ void quantize_per_tensor_impl( // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); VK_CHECK_COND(graph.val_is_tensor(output)); // Verify input is a floating point type @@ -435,6 +542,18 @@ void quantize_per_tensor_impl( graph.dtype_of(input) == vkapi::kFloat || graph.dtype_of(input) == vkapi::kHalf); + // Get scale and zero point dtypes + vkapi::ScalarType scale_dtype = graph.dtype_of(scale); + vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); + + // Verify supported types for scale (fp32 only for now) + VK_CHECK_COND(scale_dtype == vkapi::kFloat); + + // Verify supported types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || + zero_point_dtype == vkapi::kFloat); + add_quantize_per_tensor_node( graph, input, scale, zero_point, quant_min, quant_max, output); } @@ -466,6 +585,18 @@ void quantize_per_token_impl( graph.dtype_of(input) == vkapi::kFloat || graph.dtype_of(input) == vkapi::kHalf); + // Get scale and zero point dtypes + vkapi::ScalarType scale_dtype = graph.dtype_of(scale); + vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); + + // Verify supported types for scale (fp32 only for now) + VK_CHECK_COND(scale_dtype == vkapi::kFloat); + + // Verify supported types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || + zero_point_dtype == vkapi::kFloat); + // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -539,6 +670,18 @@ void quantize_per_channel_impl( graph.dtype_of(input) == vkapi::kFloat || graph.dtype_of(input) == vkapi::kHalf); + // Get scale and zero point dtypes + vkapi::ScalarType scale_dtype = graph.dtype_of(scale); + vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); + + // Verify supported types for scale (fp32 only for now) + VK_CHECK_COND(scale_dtype == vkapi::kFloat); + + // Verify supported types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || + zero_point_dtype == vkapi::kFloat); + // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -617,6 +760,18 @@ void quantize_affine_impl( graph.dtype_of(input) == vkapi::kFloat || graph.dtype_of(input) == vkapi::kHalf); + // Get scale and zero point dtypes + vkapi::ScalarType scale_dtype = graph.dtype_of(scale); + vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); + + // Verify supported types for scale (fp32 only for now) + VK_CHECK_COND(scale_dtype == vkapi::kFloat); + + // Verify supported types for zero point (int32, int8, fp32) + VK_CHECK_COND( + zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || + zero_point_dtype == vkapi::kFloat); + // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); diff --git 
a/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h index 8e10c4e2bfa..270bdd1cd6b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h @@ -28,4 +28,22 @@ T extract_scalar(const Value& value) { VK_THROW("Cannot extract scalar from Value with type ", value.type()); } +// Helper function to get default quant_min and quant_max based on dtype +// This matches the logic in _get_and_check_qmin_qmax from quant_primitives.py +inline std::pair<int, int> get_dtype_bounds(vkapi::ScalarType dtype) { + switch (dtype) { + case vkapi::kByte: // uint8 + return {0, 255}; + case vkapi::kChar: // int8 + return {-128, 127}; + case vkapi::kShort: // int16 + return {-(1 << 15), (1 << 15) - 1}; + case vkapi::kInt: // int32 + return {-(1LL << 31), (1LL << 31) - 1}; + default: + // For unsupported types, throw an error instead of assuming int8 + VK_THROW("Unsupported dtype for quantization bounds: ", dtype); + } +} + } // namespace vkcompute diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 4f54bc638ba..6b05890c3c7 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -155,6 +155,9 @@ def test_fuse_linear_qcs4w(self): self.assertEqual(op_node_count(gm, "linear_qcs4w.default"), 1) self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) + @unittest.skip( + "linear_qta8a_qga4w currently does not support E2E dynamic quantization" + ) def test_fuse_linear_qta8a_qga4w(self): """Test fusion of dynamic activation + grouped weight quantized linear (QTA8A_QGA4W).""" K = 256
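
Reviewer note (not part of the patch): the C++ changes above fall back to dtype-derived bounds whenever `quant_min`/`quant_max` are `None`, via the new `get_dtype_bounds` helper. The following is a minimal standalone sketch of that fallback pattern; the `MockDtype` enum and `resolve_bounds` helper are hypothetical stand-ins, not the real `vkapi::ScalarType` API, and exist only to illustrate the logic.

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <utility>

// Stand-in for vkapi::ScalarType; only the integer dtypes relevant to quantization.
enum class MockDtype { UInt8, Int8, Int16, Int32 };

// Mirrors get_dtype_bounds(): default [quant_min, quant_max] per target dtype.
std::pair<int, int> bounds_for(MockDtype dtype) {
  switch (dtype) {
    case MockDtype::UInt8: return {0, 255};
    case MockDtype::Int8:  return {-128, 127};
    case MockDtype::Int16: return {-(1 << 15), (1 << 15) - 1};
    case MockDtype::Int32: return {INT32_MIN, INT32_MAX};
  }
  throw std::invalid_argument("unsupported dtype for quantization bounds");
}

// Mirrors the repeated pattern in the add_quantize_*/add_dequantize_* nodes:
// an explicit graph value wins, otherwise use the dtype-derived default.
std::pair<int, int> resolve_bounds(
    MockDtype target_dtype,
    std::optional<int> quant_min,
    std::optional<int> quant_max) {
  const auto defaults = bounds_for(target_dtype);
  return {quant_min.value_or(defaults.first),
          quant_max.value_or(defaults.second)};
}

int main() {
  const auto [qmin, qmax] =
      resolve_bounds(MockDtype::Int8, std::nullopt, std::nullopt);
  std::cout << qmin << " " << qmax << "\n";  // prints: -128 127
}
```

In the patch itself, the dequantize paths derive the defaults from the (quantized) input dtype, the quantize paths from the (quantized) output dtype, and `add_choose_qparams_tensor_node` from the zero-point output dtype, while the block-wise choose_qparams path currently hardcodes the int8 range when the bounds are omitted.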
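
Reviewer note (not part of the patch): the GLSL changes let the scale and zero-point buffers carry their own dtypes (float scales; int32/int8/float zero points), and every access is wrapped in an explicit cast (`float(t_scale[i])`, `int(t_zero_point[i])`, `SCALE_OUT_T(...)`, `ZP_OUT_T(...)`) so the kernels keep computing in float/int regardless of the storage type. The sketch below replays the same affine round-trip on the CPU, assuming the standard formulas that the shaders' `quantize_val`/`dequantize_val` helpers implement; it only illustrates that widening an int8 zero point to int at the point of use is lossless.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Affine quantization, assuming the usual formulas:
//   q  = clamp(round(v / scale) + zero_point, quant_min, quant_max)
//   v' = (q - zero_point) * scale
int8_t quantize_val(float v, float scale, int zero_point, int qmin, int qmax) {
  const int q = static_cast<int>(std::round(v / scale)) + zero_point;
  return static_cast<int8_t>(std::clamp(q, qmin, qmax));
}

float dequantize_val(int8_t q, float scale, int zero_point) {
  return static_cast<float>(static_cast<int>(q) - zero_point) * scale;
}

int main() {
  // Storage mirrors the typed buffers: scale kept as float, zero point as int8.
  const float scale_storage = 0.05f;
  const int8_t zp_storage = -3;

  // At the point of use, widen exactly as the shaders do with
  // float(t_scale[i]) / int(t_zero_point[i]).
  const float scale = scale_storage;
  const int zero_point = static_cast<int>(zp_storage);

  const float v = 1.234f;
  const int8_t q = quantize_val(v, scale, zero_point, -128, 127);
  const float v_back = dequantize_val(q, scale, zero_point);
  std::cout << static_cast<int>(q) << " " << v_back << "\n";  // prints: 22 1.25
}
```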