
Commit 35c1f52

[ET-VK] Add 'half' variants to some Llama operators + enable llama vulkan export with force_fp16 flag (#14245)
Title says it all! Differential Revision: [D82234179](https://our.internmc.facebook.com/intern/diff/D82234179/)
1 parent: 9a9db14

10 files changed: +27 −9 lines

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl

Lines changed: 5 additions & 3 deletions
@@ -19,6 +19,8 @@
 #define MAX_THREADS 256

 ${define_active_storage_type(STORAGE)}
+
+${define_required_extensions(DTYPE)}
 ${define_required_extensions("int8")}

 #extension GL_EXT_control_flow_attributes : require
@@ -126,8 +128,8 @@ void find_min_max_for_row(const int output_y) {
   const int X4 = div_4(input_sizes.x);

   // Initialize thread-local min/max
-  float local_min = 1e30;
-  float local_max = -1e30;
+  T local_min = T(1e30);
+  T local_max = T(-1e30);

   // Each thread processes elements along their assigned output_id with stride
   // NUM_WORKERS_PER_OUTPUT
@@ -187,7 +189,7 @@ void main() {
     calculate_scale_and_zero_point(
         local_min, local_max, quant_min, quant_max, scale, zero_point);

-    scales_out[i] = scale;
+    scales_out[i] = T(scale);
     zps_out[i] = zero_point;
   }
 }
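
The per-row quantization-parameter math that this shader parallelizes can be summarized roughly as below. This is a hedged NumPy sketch of a typical asymmetric scheme; the shader's calculate_scale_and_zero_point may clamp and round differently. With the new half variant, T is fp16, so scales are written back at reduced precision (hence the T(scale) cast above).

import numpy as np

def choose_qparams_per_row(x, quant_min=-128, quant_max=127, out_dtype=np.float16):
    # Per-row min/max, mirroring find_min_max_for_row in the shader.
    row_min = np.minimum(x.min(axis=1), 0.0)
    row_max = np.maximum(x.max(axis=1), 0.0)
    # Typical asymmetric scale/zero-point; the shader's exact formula may differ.
    scale = (row_max - row_min) / float(quant_max - quant_min)
    scale = np.where(scale == 0.0, 1.0, scale)
    zero_point = np.clip(np.round(quant_min - row_min / scale), quant_min, quant_max)
    # The half variant stores scales as fp16, matching scales_out[i] = T(scale).
    return scale.astype(out_dtype), zero_point.astype(np.int32)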

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml

Lines changed: 1 addition & 0 deletions
@@ -14,5 +14,6 @@ choose_qparams_per_row:
       - VALUE: buffer
     DTYPE:
       - VALUE: float
+      - VALUE: half
   shader_variants:
     - NAME: choose_qparams_per_row

backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ linear_dq8ca_q4gsw_tiled:
   generate_variant_forall:
     DTYPE:
       - VALUE: float
+      - VALUE: half
   shader_variants:
     - NAME: linear_dq8ca_q4gsw_tiled_texture3d_texture2d
     - NAME: linear_dq8ca_q4gsw_tiled_texture3d_buffer

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ void accumulate_out_tile_with_int_accum_from_int4_weights(

       out_tile.data[m][n4] =
           fma(VEC4_T(accum_adjusted),
-              input_scale_m * weight_scales.data[n4],
+              VEC4_T(input_scale_m * weight_scales.data[n4]),
               out_tile.data[m][n4]);
     }
   }

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh

Lines changed: 2 additions & 2 deletions
@@ -75,7 +75,7 @@ void accumulate_out_tile_with_int_accum(
           input_zp_vec * weight_sums.data[n4] + accum.data[m][n4];
       out_tile.data[m][n4] =
           fma(VEC4_T(accum_adjusted),
-              input_q_scale * weight_scales.data[0],
+              VEC4_T(input_q_scale * weight_scales.data[0]),
               out_tile.data[m][n4]);
     }
   }
@@ -98,7 +98,7 @@ void accumulate_out_tile_with_int_accum(
           input_zp_vec * weight_sums.data[n4] + accum.data[m][n4];
       out_tile.data[m][n4] =
           fma(VEC4_T(accum_adjusted),
-              input_q_scale * weight_scales.data[n4],
+              VEC4_T(input_q_scale * weight_scales.data[n4]),
               out_tile.data[m][n4]);
       out_tile.data[m][n4] += bias.data[n4];
     }
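
The two .glslh changes above are type fixes rather than math changes: once half variants are generated, out_tile and VEC4_T are fp16 vectors, so the scale product passed to fma must also be wrapped in VEC4_T(...) so that both multiplicands share the output's component type. A loose NumPy analogue of the accumulate step, with illustrative names not taken from the shaders:

import numpy as np

def accumulate_out_tile(out_tile, accum_adjusted, input_scale, weight_scales,
                        dtype=np.float16):
    # fma(a, b, c) == a * b + c; in the shader both a and b are now VEC4_T,
    # i.e. cast to the (possibly fp16) output component type before the fma.
    a = accum_adjusted.astype(dtype)
    b = (input_scale * weight_scales).astype(dtype)  # the newly wrapped operand
    return a * b + out_tile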

backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl

Lines changed: 3 additions & 1 deletion
@@ -16,6 +16,8 @@

 ${define_active_storage_type(STORAGE)}

+${define_required_extensions(DTYPE)}
+
 #extension GL_EXT_control_flow_attributes : require

 layout(std430) buffer;
@@ -85,7 +87,7 @@ void main() {
   }

   // Initialize thread-local min/max
-  T local_exp_sum = 0;
+  T local_exp_sum = T(0);

   const int context_len_aligned_down = context_len - mod_4(context_len);
   const int C4_limit = div_4(context_len_aligned_down);

backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml

Lines changed: 1 addition & 0 deletions
@@ -14,5 +14,6 @@ sdpa_attn_weights_softmax:
       - VALUE: buffer
     DTYPE:
       - VALUE: float
+      - VALUE: half
   shader_variants:
     - NAME: sdpa_attn_weights_softmax

examples/models/llama/export_llama_lib.py

Lines changed: 4 additions & 0 deletions
@@ -418,6 +418,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Delegate more operators beyond DQLinear to the xnnpack backend. Requires -X or --xnnpack to be set.",
     )
     parser.add_argument("-V", "--vulkan", action="store_true")
+    parser.add_argument("--vulkan-force-fp16", action="store_true")
     parser.add_argument("--mps", action="store_true")
     parser.add_argument("--coreml", action="store_true")
     parser.add_argument(
@@ -885,6 +886,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     use_kv_cache: bool = False,
     embedding_quantize: Optional[str] = None,
     pt2e_quantize: Optional[str] = None,
+    vulkan_force_fp16: bool = False,
     coreml_ios: int = 15,
     coreml_quantize: Optional[str] = None,
     coreml_compute_units: str = "cpu_only",
@@ -905,6 +907,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
             get_vulkan_partitioner(
                 dtype_override,
                 enable_dynamic_shape,
+                vulkan_force_fp16,
             )
         )
         modelname = f"vulkan_{modelname}"
@@ -1125,6 +1128,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             if llm_config.quantization.pt2e_quantize
             else None
         ),
+        vulkan_force_fp16=llm_config.backend.vulkan.force_fp16,
         coreml_ios=llm_config.backend.coreml.ios,
         coreml_quantize=(
             llm_config.backend.coreml.quantize.value
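
Taken together, these hunks thread the new flag from the CLI down to the partitioner: --vulkan-force-fp16 lands on the args namespace, _export_llama forwards llm_config.backend.vulkan.force_fp16 to _to_edge_and_lower_llama, and that in turn passes it as the new third positional argument of get_vulkan_partitioner. A hedged sketch of driving the flag through the existing parser; the import path is assumed from the file location, and it assumes no other arguments are required for parsing:

from executorch.examples.models.llama.export_llama_lib import build_args_parser

parser = build_args_parser()
args = parser.parse_args(["-V", "--vulkan-force-fp16"])
# argparse maps --vulkan-force-fp16 onto the attribute vulkan_force_fp16.
assert args.vulkan and args.vulkan_force_fp16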

extension/llm/export/config/llm_config.py

Lines changed: 3 additions & 0 deletions
@@ -426,6 +426,7 @@ class VulkanConfig:
     """

     enabled: bool = False
+    force_fp16: bool = False


 @dataclass
@@ -610,6 +611,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         # Vulkan
         if hasattr(args, "vulkan"):
             llm_config.backend.vulkan.enabled = args.vulkan
+        if hasattr(args, "vulkan_force_fp16"):
+            llm_config.backend.vulkan.force_fp16 = args.vulkan_force_fp16

         # QNN
         if hasattr(args, "qnn"):
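
Because the Vulkan branch of from_args guards each flag with hasattr, a sparse namespace is enough to exercise the new mapping in isolation. A minimal sketch, assuming the other from_args branches are similarly hasattr-guarded and that the import path below matches the file location shown here:

import argparse

from executorch.extension.llm.export.config.llm_config import LlmConfig

args = argparse.Namespace(vulkan=True, vulkan_force_fp16=True)
llm_config = LlmConfig.from_args(args)
assert llm_config.backend.vulkan.enabled is True
assert llm_config.backend.vulkan.force_fp16 is True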

extension/llm/export/partitioner_lib.py

Lines changed: 6 additions & 2 deletions
@@ -32,7 +32,9 @@ def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True):


 def get_vulkan_partitioner(
-    dtype_override: Optional[str] = None, enable_dynamic_shape: bool = False
+    dtype_override: Optional[str] = None,
+    enable_dynamic_shape: bool = False,
+    force_fp16: bool = False,
 ):
     assert (
         dtype_override == "fp32" or dtype_override is None
@@ -41,7 +43,9 @@ def get_vulkan_partitioner(
         VulkanPartitioner,
     )

-    return VulkanPartitioner({"require_dynamic_shapes": enable_dynamic_shape})
+    return VulkanPartitioner(
+        {"require_dynamic_shapes": enable_dynamic_shape, "force_fp16": force_fp16}
+    )


 def get_mps_partitioner(use_kv_cache: bool = False):
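
The helper keeps its old behavior for existing callers (force_fp16 defaults to False); the new flag simply lands in the options dict handed to VulkanPartitioner. A small usage sketch, assuming the ExecuTorch Vulkan backend is built and importable and using the path implied by this file's location:

from executorch.extension.llm.export.partitioner_lib import get_vulkan_partitioner

# Existing callers: unchanged behavior, force_fp16 defaults to False.
default_partitioner = get_vulkan_partitioner(None, enable_dynamic_shape=True)

# New callers: request fp16 compute via the option added in this commit.
fp16_partitioner = get_vulkan_partitioner(
    None, enable_dynamic_shape=True, force_fp16=True
)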
