
Commit 4810bc2

Update base for Update on "[ET-VK] Replace Uniform buffers with push constants for copy op"
This diff replaces uniform buffers with push constants for the copy op in the Vulkan backend of ExecuTorch. The changes update the GLSL code to use push constants instead of uniform buffers and update the C++ code to pass the sizes to the shader as push constants.

Differential Revision: [D66890851](https://our.internmc.facebook.com/intern/diff/D66890851/)

[ghstack-poisoned]
2 parents 3511b07 + de74961
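As context for the change, the sketch below illustrates the general Vulkan mechanism the commit moves to: small, frequently changing data such as tensor sizes is recorded straight into the command buffer with vkCmdPushConstants instead of being staged in a uniform buffer and bound through a descriptor set; on the GLSL side the usual counterpart is a `layout(push_constant) uniform` block replacing the UBO declaration. This is an illustrative sketch only, not ET-VK code; `SizesPC`, `record_copy_dispatch`, and the command buffer / pipeline layout handles are placeholders assumed for the example.

```cpp
// Minimal sketch (raw Vulkan, not ET-VK): pass tensor sizes to a compute
// shader as push constants rather than through a uniform buffer.
#include <cstdint>
#include <vulkan/vulkan.h>

struct SizesPC {
  int32_t in_sizes[4];   // source tensor sizes (placeholder layout)
  int32_t out_sizes[4];  // destination tensor sizes (placeholder layout)
};

void record_copy_dispatch(
    VkCommandBuffer cmd,
    VkPipelineLayout pipeline_layout,
    const SizesPC& sizes) {
  // With a uniform buffer, this data would be written into a mapped buffer
  // and bound via a descriptor set; push constants skip both steps.
  vkCmdPushConstants(
      cmd,
      pipeline_layout,
      VK_SHADER_STAGE_COMPUTE_BIT,
      /*offset=*/0,
      sizeof(SizesPC),
      &sizes);
  vkCmdDispatch(cmd, 1, 1, 1);  // workgroup counts are placeholders
}
```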

File tree: 13 files changed (+75, -37 lines)

backends/arm/test/ops/test_layer_norm.py

Lines changed: 4 additions & 4 deletions
@@ -157,9 +157,9 @@ def test_layer_norm_tosa_BI(

     # Numerical issues on FVP likely due to mul op, MLETORCH-521
     # Skip tests that require transposes.
-    @parameterized.expand(test_data_suite[:-2])
+    @parameterized.expand(test_data_suite)
     @unittest.expectedFailure
-    def test_layer_norm_u55_BI(
+    def test_layer_norm_u55_BI_xfails(
         self,
         test_name: str,
         test_data: torch.Tensor,
@@ -171,7 +171,8 @@ def test_layer_norm_u55_BI(

     # Numerical issues on FVP likely due to mul op, MLETORCH-521
     @parameterized.expand(test_data_suite[:-2])
-    def test_layer_norm_u85_BI_fvp(
+    @unittest.expectedFailure
+    def test_layer_norm_u85_BI_xfails(
         self,
         test_name: str,
         test_data: torch.Tensor,
@@ -182,7 +183,6 @@ def test_layer_norm_u85_BI_fvp(
         )

     @parameterized.expand(test_data_suite[-2:])
-    @unittest.skip # Flaky
     def test_layer_norm_u85_BI(
         self,
         test_name: str,

backends/vulkan/docs/android_demo.md

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ partially lower the Llama model to Vulkan.
 # The files will usually be downloaded to ~/.llama
 python -m examples.models.llama.export_llama \
   --disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \
+  --model "llama3_2" \
   -c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \
   -p ~/.llama/checkpoints/Llama3.2-1B/params.json \
   --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 6 additions & 4 deletions
@@ -75,19 +75,21 @@ void DispatchNode::encode(ComputeGraph* graph) {

   bind_params_to_descriptor_set(params_, descriptor_set, idx);

-  uint8_t push_constants_data[128];
+  std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
   uint32_t push_constants_offset = 0;

   for (const auto& push_constant : push_constants_) {
-    push_constants_offset +=
-        push_constant.write(push_constants_data, push_constants_offset, 128);
+    push_constants_offset += push_constant.write(
+        push_constants_data.data(),
+        push_constants_offset,
+        kMaxPushConstantSize);
   }
   context->register_shader_dispatch(
       descriptor_set,
       pipeline_barrier,
       shader_,
       global_workgroup_size_,
-      push_constants_data,
+      push_constants_data.data(),
       push_constants_offset);

   context->report_shader_dispatch_end();
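The loop above packs each push constant entry into a fixed staging array and accumulates the number of bytes written, which is then handed to the dispatch together with the data pointer. A minimal sketch of that packing pattern, with a hypothetical `PushConstantEntry` type standing in for the actual ET-VK `PushConstantDataInfo`, could look like this:

```cpp
// Hypothetical sketch of the packing pattern; the entry type and its write()
// helper are illustrative, not the real ET-VK API.
#include <array>
#include <cstdint>
#include <cstring>
#include <vector>

constexpr uint32_t kMaxPushConstantSize = 128;

struct PushConstantEntry {
  const void* src;
  uint32_t nbytes;

  // Appends this entry's bytes at `offset` and returns how many were written;
  // a real implementation would likely assert instead of silently skipping.
  uint32_t write(uint8_t* dst, uint32_t offset, uint32_t max_size) const {
    if (offset + nbytes > max_size) {
      return 0;
    }
    std::memcpy(dst + offset, src, nbytes);
    return nbytes;
  }
};

uint32_t pack(
    const std::vector<PushConstantEntry>& entries,
    std::array<uint8_t, kMaxPushConstantSize>& out) {
  uint32_t offset = 0;
  for (const auto& entry : entries) {
    offset += entry.write(out.data(), offset, kMaxPushConstantSize);
  }
  return offset;  // total size passed to the shader dispatch with out.data()
}
```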

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ namespace vkcompute {

 class ComputeGraph;

+constexpr uint32_t kMaxPushConstantSize = 128;
 /*
  * Represents a push constant data entry
  * Which is either shared pointer to a tensor's uniform data with an attribute
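The 128-byte value matches the minimum `VkPhysicalDeviceLimits::maxPushConstantsSize` that the Vulkan specification guarantees, so a budget of `kMaxPushConstantSize` bytes is portable across devices. As a rough sketch of how such a budget typically shows up in generic Vulkan pipeline-layout setup (again, not the ET-VK implementation):

```cpp
// Generic Vulkan sketch: declare the push constant range used by a compute
// pipeline layout, capped at the 128-byte budget.
#include <vulkan/vulkan.h>

constexpr uint32_t kMaxPushConstantSize = 128;

VkPushConstantRange make_push_constant_range(uint32_t used_bytes) {
  VkPushConstantRange range{};
  range.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
  range.offset = 0;
  // The spec guarantees at least 128 bytes of push constant storage, so
  // clamping to kMaxPushConstantSize keeps the range valid on any device.
  range.size =
      used_bytes <= kMaxPushConstantSize ? used_bytes : kMaxPushConstantSize;
  return range;
}
```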

devtools/bundled_program/core.py

Lines changed: 20 additions & 12 deletions
@@ -9,7 +9,6 @@
 from typing import Dict, List, Optional, Sequence, Type, Union

 import executorch.devtools.bundled_program.schema as bp_schema
-from pyre_extensions import none_throws

 import executorch.exir.schema as core_schema

@@ -44,10 +43,12 @@ class BundledProgram:

     def __init__(
         self,
-        executorch_program: Optional[Union[
-            ExecutorchProgram,
-            ExecutorchProgramManager,
-        ]],
+        executorch_program: Optional[
+            Union[
+                ExecutorchProgram,
+                ExecutorchProgramManager,
+            ]
+        ],
         method_test_suites: Sequence[MethodTestSuite],
         pte_file_path: Optional[str] = None,
     ):
@@ -59,18 +60,24 @@ def __init__(
             pte_file_path: The path to pte file to deserialize program if executorch_program is not provided.
         """
         if not executorch_program and not pte_file_path:
-            raise RuntimeError("Either executorch_program or pte_file_path must be provided")
+            raise RuntimeError(
+                "Either executorch_program or pte_file_path must be provided"
+            )

         if executorch_program and pte_file_path:
-            raise RuntimeError("Only one of executorch_program or pte_file_path can be used")
+            raise RuntimeError(
+                "Only one of executorch_program or pte_file_path can be used"
+            )

         method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name)
         if executorch_program:
             self._assert_valid_bundle(executorch_program, method_test_suites)
-            self.executorch_program: Optional[Union[
-                ExecutorchProgram,
-                ExecutorchProgramManager,
-            ]] = executorch_program
+            self.executorch_program: Optional[
+                Union[
+                    ExecutorchProgram,
+                    ExecutorchProgramManager,
+                ]
+            ] = executorch_program
         self._pte_file_path: Optional[str] = pte_file_path

         self.method_test_suites = method_test_suites
@@ -88,7 +95,8 @@ def serialize_to_schema(self) -> bp_schema.BundledProgram:
         if self.executorch_program:
             program = self._extract_program(self.executorch_program)
         else:
-            with open(none_throws(self._pte_file_path), "rb") as f:
+            assert self._pte_file_path is not None
+            with open(self._pte_file_path, "rb") as f:
                 p_bytes = f.read()
                 program = _deserialize_pte_binary(p_bytes)

devtools/bundled_program/test/test_bundle_data.py

Lines changed: 11 additions & 4 deletions
@@ -6,9 +6,10 @@

 # pyre-strict

+import tempfile
 import unittest
 from typing import List
-import tempfile
+
 import executorch.devtools.bundled_program.schema as bp_schema

 import torch
@@ -73,7 +74,7 @@ def test_bundled_program(self) -> None:
             bundled_program.serialize_to_schema().program,
             bytes(_serialize_pte_binary(executorch_program.executorch_program)),
         )
-
+
     def test_bundled_program_from_pte(self) -> None:
         executorch_program, method_test_suites = get_common_executorch_program()

@@ -82,11 +83,17 @@ def test_bundled_program_from_pte(self) -> None:
         with open(executorch_model_path, "wb") as f:
             f.write(executorch_program.buffer)

-        bundled_program = BundledProgram(executorch_program=None, method_test_suites=method_test_suites, pte_file_path=executorch_model_path)
+        bundled_program = BundledProgram(
+            executorch_program=None,
+            method_test_suites=method_test_suites,
+            pte_file_path=executorch_model_path,
+        )

         method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name)

-        for plan_id in range(len(executorch_program.executorch_program.execution_plan)):
+        for plan_id in range(
+            len(executorch_program.executorch_program.execution_plan)
+        ):
             bundled_plan_test = (
                 bundled_program.serialize_to_schema().method_test_suites[plan_id]
             )

examples/arm/setup.sh

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ tosa_reference_model_rev="c5570b79e90c3a36ab8c4ddb8ee3fbc2cd3f7c38"

 # vela
 vela_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u-vela"
-vela_rev="a08fc18780827b5fefc814dd0162ee6317ce0ae7"
+vela_rev="5427dc7e9c1a4c7d554163290faeea75f168772d"

 ########
 ### Mandatory user args

examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md

Lines changed: 3 additions & 3 deletions
@@ -56,14 +56,14 @@ In this demo app, we support text-only inference with up-to-date Llama models an
 Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
 ```

 ### For Llama 3.2 1B and 3B QAT+LoRA models
 Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
 ```

 ### For Llama 3.2 1B and 3B BF16 models
@@ -72,7 +72,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
 * Export Llama model and generate .pte file as below:

 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
 ```

 For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).

examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md

Lines changed: 3 additions & 3 deletions
@@ -48,14 +48,14 @@ sh examples/models/llama/install_requirements.sh
 Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
 ```

 ### For Llama 3.2 1B and 3B QAT+LoRA models
 Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
 ```

 ### For Llama 3.2 1B and 3B BF16 models
@@ -64,7 +64,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
 * Export Llama model and generate .pte file as below:

 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
 ```

 For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).

examples/models/llama/README.md

Lines changed: 3 additions & 0 deletions
@@ -168,6 +168,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json

 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -kv \
@@ -189,6 +190,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
 LLAMA_PARAMS=path/to/spinquant/params.json

 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   --use_sdpa_with_kv_cache \
@@ -214,6 +216,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
 LLAMA_PARAMS=path/to/qlora/params.json

 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -qat \
