Skip to content

Commit 98227fd

Browse files
committed
Update base for Update on "[ET-VK] Add PushConstantDataInfo and vector to hold push constants data in DispatchNode."
This diff adds a new class called `PushConstantDataInfo` to the `DispatchNode` class in the Vulkan backend for Executorch. This class represents a push constant data entry, which can either be a shared pointer to a tensor's uniform data with an attribute or data with a maximum size of 16 bytes. The `write` method is also added to this class, which writes the data to a destination buffer. Differential Revision: [D66796049](https://our.internmc.facebook.com/intern/diff/D66796049/) [ghstack-poisoned]
2 parents 59f6a27 + 957259e commit 98227fd

File tree

9 files changed

+37
-17
lines changed

9 files changed

+37
-17
lines changed

backends/arm/test/ops/test_layer_norm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,9 @@ def test_layer_norm_tosa_BI(
157157

158158
# Numerical issues on FVP likely due to mul op, MLETORCH-521
159159
# Skip tests that require transposes.
160-
@parameterized.expand(test_data_suite[:-2])
160+
@parameterized.expand(test_data_suite)
161161
@unittest.expectedFailure
162-
def test_layer_norm_u55_BI(
162+
def test_layer_norm_u55_BI_xfails(
163163
self,
164164
test_name: str,
165165
test_data: torch.Tensor,
@@ -171,7 +171,8 @@ def test_layer_norm_u55_BI(
171171

172172
# Numerical issues on FVP likely due to mul op, MLETORCH-521
173173
@parameterized.expand(test_data_suite[:-2])
174-
def test_layer_norm_u85_BI_fvp(
174+
@unittest.expectedFailure
175+
def test_layer_norm_u85_BI_xfails(
175176
self,
176177
test_name: str,
177178
test_data: torch.Tensor,
@@ -182,7 +183,6 @@ def test_layer_norm_u85_BI_fvp(
182183
)
183184

184185
@parameterized.expand(test_data_suite[-2:])
185-
@unittest.skip # Flaky
186186
def test_layer_norm_u85_BI(
187187
self,
188188
test_name: str,

backends/vulkan/docs/android_demo.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ partially lower the Llama model to Vulkan.
5959
# The files will usually be downloaded to ~/.llama
6060
python -m examples.models.llama.export_llama \
6161
--disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \
62+
--model "llama3_2" \
6263
-c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \
6364
-p ~/.llama/checkpoints/Llama3.2-1B/params.json \
6465
--metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

examples/arm/setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ tosa_reference_model_rev="c5570b79e90c3a36ab8c4ddb8ee3fbc2cd3f7c38"
9292

9393
# vela
9494
vela_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u-vela"
95-
vela_rev="a08fc18780827b5fefc814dd0162ee6317ce0ae7"
95+
vela_rev="5427dc7e9c1a4c7d554163290faeea75f168772d"
9696

9797
########
9898
### Mandatory user args

examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ In this demo app, we support text-only inference with up-to-date Llama models an
5656
Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
5757
* Export Llama model and generate .pte file as below:
5858
```
59-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
59+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
6060
```
6161

6262
### For Llama 3.2 1B and 3B QAT+LoRA models
6363
Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
6464
* Export Llama model and generate .pte file as below:
6565
```
66-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
66+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
6767
```
6868

6969
### For Llama 3.2 1B and 3B BF16 models
@@ -72,7 +72,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
7272
* Export Llama model and generate .pte file as below:
7373

7474
```
75-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
75+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
7676
```
7777

7878
For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).

examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ sh examples/models/llama/install_requirements.sh
4848
Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
4949
* Export Llama model and generate .pte file as below:
5050
```
51-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
51+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
5252
```
5353

5454
### For Llama 3.2 1B and 3B QAT+LoRA models
5555
Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
5656
* Export Llama model and generate .pte file as below:
5757
```
58-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
58+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
5959
```
6060

6161
### For Llama 3.2 1B and 3B BF16 models
@@ -64,7 +64,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
6464
* Export Llama model and generate .pte file as below:
6565

6666
```
67-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
67+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
6868
```
6969

7070
For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).

examples/models/llama/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
168168
LLAMA_PARAMS=path/to/params.json
169169
170170
python -m examples.models.llama.export_llama \
171+
--model "llama3_2" \
171172
--checkpoint "${LLAMA_CHECKPOINT:?}" \
172173
--params "${LLAMA_PARAMS:?}" \
173174
-kv \
@@ -189,6 +190,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
189190
LLAMA_PARAMS=path/to/spinquant/params.json
190191
191192
python -m examples.models.llama.export_llama \
193+
--model "llama3_2" \
192194
--checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
193195
--params "${LLAMA_PARAMS:?}" \
194196
--use_sdpa_with_kv_cache \
@@ -214,6 +216,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
214216
LLAMA_PARAMS=path/to/qlora/params.json
215217
216218
python -m examples.models.llama.export_llama \
219+
--model "llama3_2" \
217220
--checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
218221
--params "${LLAMA_PARAMS:?}" \
219222
-qat \

examples/models/llama/llama_transformer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class ModelArgs:
113113
)
114114
rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC.
115115
use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
116+
rope_scale_factor: int = 8
116117
# Additional Model Metadata needed at runtime
117118
bos_idx: int = 1
118119
eos_idx: int = 3
@@ -155,7 +156,9 @@ def __init__(self, params: ModelArgs):
155156
self.precompute_freqs_cis = hf_precompute_freqs_cis
156157
else:
157158
self.precompute_freqs_cis = partial(
158-
precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
159+
precompute_freqs_cis,
160+
use_scaled=self.params.use_scaled_rope,
161+
scale_factor=self.params.rope_scale_factor,
159162
)
160163
freqs_cos, freqs_sin = self.precompute_freqs_cis(
161164
self.params.head_dim,

examples/models/llama/model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,15 @@ def __init__(self, **kwargs):
145145
enable_dynamic_shape=self.enable_dynamic_shape,
146146
**params,
147147
)
148+
149+
if model_args.use_scaled_rope:
150+
# Older models don't have use_scaled_rope configuration
151+
assert self.args.model not in ["llama2", "stories110m"]
152+
153+
# Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
154+
if self.args.model not in ["llama3", "llama3_1"]:
155+
model_args.rope_scale_factor = 32
156+
148157
if kwargs.get("verbose", False):
149158
print("============= weights ================")
150159
print("{key} : {weights.numel()} : {weights.size()}")

examples/models/llama/rope.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,15 @@
88
# Different RoPE implementations
99

1010
import math
11-
from typing import Tuple
11+
from typing import Optional, Tuple
1212

1313
import torch
1414

1515
# ======================== Stock Implementation ========================
1616

1717

18-
def apply_scaling(freqs: torch.Tensor):
18+
def apply_scaling(freqs: torch.Tensor, scale_factor: int):
1919
# Values obtained from grid search
20-
scale_factor = 8
2120
low_freq_factor = 1
2221
high_freq_factor = 4
2322
old_context_len = 8192 # original llama3 length
@@ -41,14 +40,19 @@ def apply_scaling(freqs: torch.Tensor):
4140

4241

4342
def precompute_freqs_cis(
44-
dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
43+
dim: int,
44+
end: int,
45+
theta: float = 10000.0,
46+
use_scaled: bool = False,
47+
scale_factor: Optional[int] = None,
4548
):
4649
freqs = 1.0 / (
4750
theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
4851
)
4952
t = torch.arange(end, device=freqs.device) # pyre-ignore
5053
if use_scaled:
51-
freqs = apply_scaling(freqs) # pyre-ignore
54+
assert scale_factor is not None
55+
freqs = apply_scaling(freqs, scale_factor) # pyre-ignore
5256
freqs = torch.outer(t, freqs).float()
5357
freqs_cos = torch.cos(freqs)
5458
freqs_sin = torch.sin(freqs)

0 commit comments

Comments
 (0)