Commit d5cd4f3

Qualcomm AI Engine Direct - Static Decoder Runner Support 16bit KV IO (#13127)
### Summary
- Support 16-bit KV IO in the runner (it can run either 8-bit or 16-bit KV models).
- Add a README section for the script that runs Qwen2.5 0.5B.
- Improve the perplexity (PPL) score for Qwen2.5 0.5B from 18 to 12.
- Fix a BC CI bug.

Sample script: `python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s $DEVICE -m SM8750 --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5 --eval_perplexity --tasks wikitext --limit 1 --artifact ./16bit_qwen_1024 --enable_masked_softmax --r3`

#### Stats with QNN 2.37.0 on SM8750
- Accuracy: 12 PPL (aligned with the prepare_pt2e and convert_pt2e results).
- Token rate: ~130 tok/sec, depending on seq_len.

### Test plan
Added an E2E test to `test_qnn_delegate.py`.
1 parent 1976647 commit d5cd4f3

16 files changed: +287 −171 lines

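The gist of the change: the export script records the KV IO bit width as model metadata, and the runner reads it back to decide whether the KV cache buffers hold 8-bit or 16-bit elements. Below is a minimal Python sketch of that selection logic, with a hypothetical helper name and a placeholder cache shape; the actual implementation is the templated C++ dispatch in `qnn_llama_runner.cpp` further down.

```python
import numpy as np

def select_kv_dtype(metadata: dict) -> np.dtype:
    # Models exported before this change carry no bit-width entry, so fall
    # back to 8-bit KV IO (this mirrors the runner's default).
    bit_width = metadata.get("get_kv_io_bit_width", 8)
    if bit_width == 8:
        return np.dtype(np.uint8)
    if bit_width == 16:
        return np.dtype(np.uint16)
    raise ValueError(f"Unsupported KV IO bit width: {bit_width}")

# Example: a model exported with --ptq 16a8w records 16-bit KV IO.
kv_cache = np.zeros((2, 1024, 64), dtype=select_kv_dtype({"get_kv_io_bit_width": 16}))
print(kv_cache.dtype)  # uint16
```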

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 3 additions & 5 deletions
@@ -158,7 +158,6 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
 
 def annotate_matmul_16a8w(  # noqa: C901
     gm: torch.fx.GraphModule,
-    annotate_conv=True,
     is_qat=False,
 ) -> None:
     """
@@ -337,10 +336,9 @@ def annotate_matmul_input1(node: Node, is_qat: str):
                # The arguments of cat op: (the past kv cache, the new kv cache)
                node = node.args[0][1]
            elif node.target == torch.ops.aten.conv2d.default:
-                if annotate_conv:
-                    annotate_conv2d(
-                        node, quantization_config=quantization_config_8a4w_per_channel
-                    )
+                annotate_conv2d(
+                    node, quantization_config=quantization_config_8a4w_per_channel
+                )
                break
            elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
                break

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 3 additions & 1 deletion
@@ -4560,6 +4560,8 @@ def test_static_qwen2_5(self):
            "wikitext",
            "--limit",
            "1",
+            "--r3",
+            "--enable_masked_softmax",
        ]
        if self.compile_only:
            cmds.extend(["--compile_only"])
@@ -4581,7 +4583,7 @@
                self.fail(msg["Error"])
            else:
                inference_speed_ref = {"SM8650": 110, "SM8750": 130}
-                self.assertLessEqual(msg["wiki_ppl"], 25)
+                self.assertLessEqual(msg["wiki_ppl"], 15)
                self.assertLessEqual(msg["pte_size"], 800000000)  # 800mb
                if self.model in inference_speed_ref:
                    self.assertGreaterEqual(

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 12 additions & 6 deletions
@@ -59,13 +59,19 @@ At the end of this step, users should have the following files ready: `consolida
 ### Step3: Run default examples using hybrid mode.
 #### LLAMA2
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --decoder_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
 ```
 
 #### LLAMA3.2
 Default example using hybrid mode.
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
+```
+
+#### QWEN2.5 0.5B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5 --prompt "I would like to learn python, could you teach me with a simple example?"
 ```
 
 ### KV Cache update mechanism
@@ -120,21 +126,21 @@ We have two distinct mechanisms for updating the key-value (KV) cache, which can
 #### Compile Only
 If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --compile_only
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --compile_only
 ```
 
 #### Pre Generated PTE
 On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
 ```
 
 #### KV Cache Updater
 
 You can select the KV Cache update mechanism at runtime by setting the `KV_UPDATER` variable to either "shift_pointer" or "smart_mask". By default, it is set to "smart_mask".
 `KV_UPDATER` = "shift_pointer"
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
 ```
 
 #### Lookahead Decoding Mode
@@ -147,7 +153,7 @@ You can choose the lookahead mode to enhance decoding speed. To use this mode, y
 For more details, please refer to the paper ["Break the Sequential Dependency of LLM Inference Using Lookahead Decoding"](https://arxiv.org/abs/2402.02057)
 
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
 ```
 
 #### Masked Softmax

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 12 additions & 10 deletions
@@ -264,8 +264,8 @@ def quantize(
 
        self.llama_graph_module = convert_pt2e(fx_graph_module)
 
-        logging.info("Verifying the QDQ model...")
        if args.eval_perplexity:
+            logging.info("Verifying the QDQ model...")
            # Check qdq cpu results
            graph_module_inference(
                args=args,
@@ -362,6 +362,7 @@ def compile(args, pte_filename, tokenizer):
    kv_config.use_kv_cache = True
    kv_config.enable_masked_softmax = args.enable_masked_softmax
    kv_config.enable_r3 = args.r3
+    kv_config.kv_io_bit_width = 16 if args.ptq == "16a8w" else 8
 
    prefill_config = copy.copy(kv_config)
    prefill_config.use_kv_cache = (
@@ -535,11 +536,15 @@ def permute(w, heads):
    fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32}
    if args.ptq:
        use_fp16 = False
-        fixed_point_type["kv_type"] = torch.uint8
        if args.ptq == "8a8w":
            fixed_point_type["io_type"] = torch.uint8
-        elif args.ptq in ("16a4w", "16a4w_block", "16a8w"):
+            fixed_point_type["kv_type"] = torch.uint8
+        elif args.ptq in ("16a4w", "16a4w_block"):
            fixed_point_type["io_type"] = torch.uint16
+            fixed_point_type["kv_type"] = torch.uint8
+        elif args.ptq == "16a8w":
+            fixed_point_type["io_type"] = torch.uint16
+            fixed_point_type["kv_type"] = torch.uint16
        else:
            assert args.ptq in [
                "8a8w",
@@ -572,13 +577,10 @@ def permute(w, heads):
 
    if args.ptq:
        start_quantize_ts = time.time()
-        custom_annotations = (
-            # For qwen2.5, skip annotate_conv can improve result.
-            partial(
-                annotate_matmul_16a8w,
-                annotate_conv=args.ptq != "16a8w",
-            ),
-        )
+        custom_annotations = ()
+        if args.ptq != "16a8w":
+            # 16a8w use 16bit kv io, so skip this custom annotation
+            custom_annotations = custom_annotations + (annotate_matmul_16a8w,)
        if args.decoder_model in {"stories110m", "stories260k"}:
            custom_annotations = custom_annotations + (
                annotate_linear_16a8w_in_affine_layer,
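For reference, the fixed-point types that the PTQ schemes above resolve to after this change can be summarized in one lookup. This is only an illustrative sketch (the `PTQ_FIXED_POINT` name is not in the script), covering just the schemes handled in the diff; 16a8w is the only scheme that moves the KV cache to 16-bit.

```python
import torch

# io_type / kv_type follow the branches in the diff above; kv_io_bit_width
# follows `16 if args.ptq == "16a8w" else 8` from the compile() change.
PTQ_FIXED_POINT = {
    "8a8w": {"io_type": torch.uint8, "kv_type": torch.uint8, "kv_io_bit_width": 8},
    "16a4w": {"io_type": torch.uint16, "kv_type": torch.uint8, "kv_io_bit_width": 8},
    "16a4w_block": {"io_type": torch.uint16, "kv_type": torch.uint8, "kv_io_bit_width": 8},
    "16a8w": {"io_type": torch.uint16, "kv_type": torch.uint16, "kv_io_bit_width": 16},
}

print(PTQ_FIXED_POINT["16a8w"])
```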

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 2 additions & 0 deletions
@@ -444,6 +444,7 @@ def __init__(
        self.output_new_cache_only = output_new_cache_only
        self.use_i64_token = use_i64_token
        self.output_cache = output_cache
+        self.kv_io_bit_width = config.kv_io_bit_width
 
        self.layers = nn.ModuleList(
            [
@@ -607,4 +608,5 @@ def get_metadata(self):
            "get_n_layers": self.n_layers,
            "get_vocab_size": self.vocab_size,
            "get_use_kv_cache": self.use_kv_cache,
+            "get_kv_io_bit_width": self.kv_io_bit_width,
        }

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 44 additions & 14 deletions
@@ -133,24 +133,16 @@ std::string get_formatted_prompt(
  return formatted_prompt;
 }
 
-int main(int argc, char** argv) {
-  std::vector<std::string> prompts = CollectPrompts(argc, argv);
-  gflags::ParseCommandLineFlags(&argc, &argv, true);
-  if (!gflags::GetCommandLineFlagInfoOrDie("prompt").is_default &&
-      !gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default) {
-    ET_CHECK_MSG(false, "Only provide prompt or tokenized_input but not both.");
-  }
-  if (!gflags::GetCommandLineFlagInfoOrDie("dump_logits_path").is_default &&
-      FLAGS_eval_mode != 0) {
-    ET_CHECK_MSG(
-        false, "Only TokenGenerator(kv) mode is supported to dump all logits.");
-  }
-
+template <typename T>
+void start_runner(
+    std::unique_ptr<executorch::extension::Module> module,
+    std::vector<std::string>& prompts) {
  bool use_tokenized_prompt =
      gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default ? false
                                                                         : true;
  // create llama runner
-  example::Runner runner(
+  example::Runner<T> runner(
+      std::move(module),
      FLAGS_decoder_model_version.c_str(),
      FLAGS_model_path.c_str(),
      FLAGS_tokenizer_path.c_str(),
@@ -196,5 +188,43 @@ int main(int argc, char** argv) {
 
  fout.write(buf.data(), buf.size());
  fout.close();
+}
+
+int main(int argc, char** argv) {
+  std::vector<std::string> prompts = CollectPrompts(argc, argv);
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  if (!gflags::GetCommandLineFlagInfoOrDie("prompt").is_default &&
+      !gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default) {
+    ET_CHECK_MSG(false, "Only provide prompt or tokenized_input but not both.");
+  }
+  if (!gflags::GetCommandLineFlagInfoOrDie("dump_logits_path").is_default &&
+      FLAGS_eval_mode != 0) {
+    ET_CHECK_MSG(
+        false, "Only TokenGenerator(kv) mode is supported to dump all logits.");
+  }
+
+  std::unique_ptr<executorch::extension::Module> module =
+      std::make_unique<executorch::extension::Module>(
+          FLAGS_model_path.c_str(),
+          executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
+  // Using 8bit as default since this meta is introduced with 16bit kv io
+  // support and older models only have 8bit kv io.
+  example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
+  if (module->method_names()->count("get_kv_io_bit_width") > 0) {
+    kv_bitwidth = static_cast<example::KvBitWidth>(
+        module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
+  }
+
+  if (kv_bitwidth == example::KvBitWidth::kWidth8) {
+    start_runner<uint8_t>(std::move(module), prompts);
+  } else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
+    start_runner<uint16_t>(std::move(module), prompts);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unsupported kv bitwidth: %ld",
+        static_cast<int64_t>(kv_bitwidth));
+  }
+
  return 0;
 }
