
Commit f9efbc5

Qualcomm AI Engine Direct - Static Decoder Runner Support 16bit KV IO

1 parent 07b6059
File tree: 16 files changed, +289 -174 lines


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 4 additions & 7 deletions
@@ -153,9 +153,7 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
     )


-def annotate_matmul_16a8w(  # noqa: C901
-    gm: torch.fx.GraphModule, annotate_conv=True
-) -> None:
+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for matmul op 16a8w.
     For k, we will tag such as the below, and
@@ -319,10 +317,9 @@ def annotate_matmul_input1(node: Node):
             # The arguments of cat op: (the past kv cache, the new kv cache)
             node = node.args[0][1]
         elif node.target == torch.ops.aten.conv2d.default:
-            if annotate_conv:
-                annotate_conv2d(
-                    node, quantization_config=quantization_config_8a4w_per_channel
-                )
+            annotate_conv2d(
+                node, quantization_config=quantization_config_8a4w_per_channel
+            )
             break
         elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
             break
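Note on the change above: the `annotate_conv` switch is removed, so callers now opt in or out of the whole annotation pass instead of toggling convolution annotation inside it (see the llama.py hunk further down). Below is a minimal Python sketch of that calling pattern; `apply_custom_annotations` is a hypothetical helper, not the repo's actual plumbing.

```python
# Minimal sketch, not the repo's exact API: annotation passes are plain
# callables over a torch.fx.GraphModule, and the 16a8w case simply omits
# annotate_matmul_16a8w instead of passing annotate_conv=False.
from typing import Callable, Tuple

import torch

AnnotationPass = Callable[[torch.fx.GraphModule], None]


def apply_custom_annotations(
    gm: torch.fx.GraphModule, passes: Tuple[AnnotationPass, ...]
) -> None:
    # Each pass mutates the graph module's node quantization metadata in place.
    for annotate in passes:
        annotate(gm)


# Usage under this assumption (annotate_matmul_16a8w imported from
# backends/qualcomm/quantizer/custom_annotation.py):
#   passes = () if ptq == "16a8w" else (annotate_matmul_16a8w,)
#   apply_custom_annotations(gm, passes)
```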

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 3 additions & 1 deletion
@@ -4360,6 +4360,8 @@ def test_static_qwen2_5(self):
             "wikitext",
             "--limit",
             "1",
+            "--r3",
+            "--enable_masked_softmax",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -4381,7 +4383,7 @@ def test_static_qwen2_5(self):
             self.fail(msg["Error"])
         else:
             inference_speed_ref = {"SM8650": 110, "SM8750": 130}
-            self.assertLessEqual(msg["wiki_ppl"], 25)
+            self.assertLessEqual(msg["wiki_ppl"], 15)
             self.assertLessEqual(msg["pte_size"], 800000000)  # 800mb
             if self.model in inference_speed_ref:
                 self.assertGreaterEqual(

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 12 additions & 6 deletions
@@ -57,13 +57,19 @@ At the end of this step, users should have the following files ready: `consolida
 ### Step3: Run default examples using hybrid mode.
 #### LLAMA2
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --decoder_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
 ```

 #### LLAMA3.2
 Default example using hybrid mode.
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
+```
+
+#### QWEN2.5 0.5B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5 --prompt "I would like to learn python, could you teach me with a simple example?"
 ```

 ### KV Cache update mechanism
@@ -118,21 +124,21 @@ We have two distinct mechanisms for updating the key-value (KV) cache, which can
 #### Compile Only
 If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --compile_only
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --compile_only
 ```

 #### Pre Generated PTE
 On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
 ```

 #### KV Cache Updater

 You can select the KV Cache update mechanism at runtime by setting the `KV_UPDATER` variable to either "shift_pointer" or "smart_mask". By default, it is set to "smart_mask".
 `KV_UPDATER` = "shift_pointer"
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
 ```

 #### Lookahead Decoding Mode
@@ -145,7 +151,7 @@ You can choose the lookahead mode to enhance decoding speed. To use this mode, y
 For more details, please refer to the paper ["Break the Sequential Dependency of LLM Inference Using Lookahead Decoding"](https://arxiv.org/abs/2402.02057)

 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
 ```

 #### Masked Softmax

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 13 additions & 11 deletions
@@ -265,8 +265,8 @@ def quantize(

         self.llama_graph_module = convert_pt2e(fx_graph_module)

-        logging.info("Verifying the QDQ model...")
         if args.eval_perplexity:
+            logging.info("Verifying the QDQ model...")
             # Check qdq cpu results
             graph_module_inference(
                 args=args,
@@ -375,6 +375,7 @@ def compile(args, pte_filename, tokenizer):
     kv_config.use_kv_cache = True
     kv_config.enable_masked_softmax = args.enable_masked_softmax
     kv_config.enable_r3 = args.r3
+    kv_config.kv_io_bit_width = 16 if args.ptq == "16a8w" else 8

     prefill_config = copy.copy(kv_config)
     prefill_config.use_kv_cache = (
@@ -551,11 +552,15 @@ def permute(w, heads):
     fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32}
     if args.ptq:
         use_fp16 = False
-        fixed_point_type["kv_type"] = torch.uint8
         if args.ptq == "8a8w":
             fixed_point_type["io_type"] = torch.uint8
-        elif args.ptq in ("16a4w", "16a4w_block", "16a8w"):
+            fixed_point_type["kv_type"] = torch.uint8
+        elif args.ptq in ("16a4w", "16a4w_block"):
             fixed_point_type["io_type"] = torch.uint16
+            fixed_point_type["kv_type"] = torch.uint8
+        elif args.ptq == "16a8w":
+            fixed_point_type["io_type"] = torch.uint16
+            fixed_point_type["kv_type"] = torch.uint16
         else:
             assert args.ptq in [
                 "8a8w",
@@ -588,14 +593,11 @@ def permute(w, heads):

     if args.ptq:
         start_quantize_ts = time.time()
-        custom_annotations = (
-            # For qwen2.5, skip annotate_conv can improve result.
-            partial(
-                annotate_matmul_16a8w,
-                annotate_conv=args.ptq != "16a8w",
-            ),
-        )
-        if args.decoder_model == {"stories110m", "stories260k"}:
+        custom_annotations = ()
+        if args.ptq != "16a8w":
+            # 16a8w use 16bit kv io, so skip this custom annotation
+            custom_annotations = custom_annotations + (annotate_matmul_16a8w,)
+        if args.decoder_model in {"stories110m", "stories260k"}:
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
             )
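Taken together, the hunks above tie each `--ptq` mode to a fixed-point IO/KV dtype pair and a KV IO bit width. A standalone restatement of that mapping, as a sketch rather than the exact llama.py code (the helper name and return shape are illustrative):

```python
# Sketch of the quantization-mode mapping from the diff above; the helper name
# and return shape are illustrative, not part of llama.py.
import torch


def kv_io_config(ptq: str):
    """Return (fixed_point_type, kv_io_bit_width) for a given --ptq mode."""
    if ptq == "8a8w":
        return {"io_type": torch.uint8, "kv_type": torch.uint8}, 8
    if ptq in ("16a4w", "16a4w_block"):
        return {"io_type": torch.uint16, "kv_type": torch.uint8}, 8
    if ptq == "16a8w":
        # New in this commit: 16a8w keeps the KV cache IO in 16 bits as well.
        return {"io_type": torch.uint16, "kv_type": torch.uint16}, 16
    raise ValueError(f"Unsupported ptq mode: {ptq}")


assert kv_io_config("16a8w") == (
    {"io_type": torch.uint16, "kv_type": torch.uint16},
    16,
)
```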

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 2 additions & 0 deletions
@@ -393,6 +393,7 @@ def __init__(
         self.output_new_cache_only = output_new_cache_only
         self.use_i64_token = use_i64_token
         self.output_cache = output_cache
+        self.kv_io_bit_width = config.kv_io_bit_width

         self.layers = nn.ModuleList(
             [
@@ -546,4 +547,5 @@ def get_metadata(self):
             "get_n_layers": self.n_layers,
             "get_vocab_size": self.vocab_size,
             "get_use_kv_cache": self.use_kv_cache,
+            "get_kv_io_bit_width": self.kv_io_bit_width,
         }

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 44 additions & 14 deletions
@@ -123,24 +123,16 @@ std::string get_formatted_prompt(
   return formatted_prompt;
 }

-int main(int argc, char** argv) {
-  std::vector<std::string> prompts = CollectPrompts(argc, argv);
-  gflags::ParseCommandLineFlags(&argc, &argv, true);
-  if (!gflags::GetCommandLineFlagInfoOrDie("prompt").is_default &&
-      !gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default) {
-    ET_CHECK_MSG(false, "Only provide prompt or tokenized_input but not both.");
-  }
-  if (!gflags::GetCommandLineFlagInfoOrDie("dump_logits_path").is_default &&
-      FLAGS_eval_mode != 0) {
-    ET_CHECK_MSG(
-        false, "Only TokenGenerator(kv) mode is supported to dump all logits.");
-  }
-
+template <typename T>
+void start_runner(
+    std::unique_ptr<executorch::extension::Module> module,
+    std::vector<std::string>& prompts) {
   bool use_tokenized_prompt =
       gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default ? false
                                                                          : true;
   // create llama runner
-  example::Runner runner(
+  example::Runner<T> runner(
+      std::move(module),
       FLAGS_decoder_model_version.c_str(),
       FLAGS_model_path.c_str(),
       FLAGS_tokenizer_path.c_str(),
@@ -186,5 +178,43 @@ int main(int argc, char** argv) {

   fout.write(buf.data(), buf.size());
   fout.close();
+}
+
+int main(int argc, char** argv) {
+  std::vector<std::string> prompts = CollectPrompts(argc, argv);
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  if (!gflags::GetCommandLineFlagInfoOrDie("prompt").is_default &&
+      !gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default) {
+    ET_CHECK_MSG(false, "Only provide prompt or tokenized_input but not both.");
+  }
+  if (!gflags::GetCommandLineFlagInfoOrDie("dump_logits_path").is_default &&
+      FLAGS_eval_mode != 0) {
+    ET_CHECK_MSG(
+        false, "Only TokenGenerator(kv) mode is supported to dump all logits.");
+  }
+
+  std::unique_ptr<executorch::extension::Module> module =
+      std::make_unique<executorch::extension::Module>(
+          FLAGS_model_path.c_str(),
+          executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
+  // Using 8bit as default since this meta is introduced with 16bit kv io
+  // support and older models only have 8bit kv io.
+  example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
+  if (module->method_names()->count("get_kv_io_bit_width") > 0) {
+    kv_bitwidth = static_cast<example::KvBitWidth>(
+        module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
+  }
+
+  if (kv_bitwidth == example::KvBitWidth::kWidth8) {
+    start_runner<uint8_t>(std::move(module), prompts);
+  } else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
+    start_runner<uint16_t>(std::move(module), prompts);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unsupported kv bitwidth: %ld",
+        static_cast<int64_t>(kv_bitwidth));
+  }
+
   return 0;
 }
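The runner now queries the optional `get_kv_io_bit_width` metadata method, defaults to 8-bit KV IO for older models that lack it, and instantiates the templated runner with `uint8_t` or `uint16_t`. A rough Python analogue of just that dispatch logic, purely illustrative (the real implementation is the C++ above; the helper and its arguments are hypothetical):

```python
# Illustrative Python sketch of the dispatch logic only; names are hypothetical
# and this is not executorch's Python API.
import numpy as np


def select_kv_dtype(method_names, read_scalar) -> np.dtype:
    """Pick the KV-cache element type advertised by the PTE's metadata.

    Older models were exported before get_kv_io_bit_width existed, so the
    default stays at 8-bit KV IO.
    """
    bit_width = 8
    if "get_kv_io_bit_width" in method_names:
        bit_width = int(read_scalar("get_kv_io_bit_width"))
    if bit_width == 8:
        return np.dtype(np.uint8)
    if bit_width == 16:
        return np.dtype(np.uint16)
    raise ValueError(f"Unsupported kv bitwidth: {bit_width}")


# Example: a model exported with --ptq 16a8w reports 16-bit KV IO.
assert select_kv_dtype({"get_kv_io_bit_width"}, lambda _: 16) == np.uint16
```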
