8 changes: 3 additions & 5 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -158,7 +158,6 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):

def annotate_matmul_16a8w( # noqa: C901
gm: torch.fx.GraphModule,
annotate_conv=True,
is_qat=False,
) -> None:
"""
@@ -337,10 +336,9 @@ def annotate_matmul_input1(node: Node, is_qat: str):
# The arguments of cat op: (the past kv cache, the new kv cache)
node = node.args[0][1]
elif node.target == torch.ops.aten.conv2d.default:
if annotate_conv:
annotate_conv2d(
node, quantization_config=quantization_config_8a4w_per_channel
)
annotate_conv2d(
node, quantization_config=quantization_config_8a4w_per_channel
)
break
elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
break
4 changes: 3 additions & 1 deletion backends/qualcomm/tests/test_qnn_delegate.py
@@ -4560,6 +4560,8 @@ def test_static_qwen2_5(self):
"wikitext",
"--limit",
"1",
"--r3",
"--enable_masked_softmax",
]
if self.compile_only:
cmds.extend(["--compile_only"])
@@ -4581,7 +4583,7 @@ def test_static_qwen2_5(self):
self.fail(msg["Error"])
else:
inference_speed_ref = {"SM8650": 110, "SM8750": 130}
self.assertLessEqual(msg["wiki_ppl"], 25)
self.assertLessEqual(msg["wiki_ppl"], 15)
self.assertLessEqual(msg["pte_size"], 800000000) # 800mb
if self.model in inference_speed_ref:
self.assertGreaterEqual(
18 changes: 12 additions & 6 deletions examples/qualcomm/oss_scripts/llama/README.md
@@ -59,13 +59,19 @@ At the end of this step, users should have the following files ready: `consolida
### Step3: Run default examples using hybrid mode.
#### LLAMA2
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --decoder_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
```

#### LLAMA3.2
Default example using hybrid mode.
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
```

#### QWEN2.5 0.5B
Default example using hybrid mode
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5 --prompt "I would like to learn python, could you teach me with a simple example?"
```

### KV Cache update mechanism
@@ -120,21 +126,21 @@ We have two distinct mechanisms for updating the key-value (KV) cache, which can
#### Compile Only
If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --compile_only
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --compile_only
```

#### Pre Generated PTE
On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
```

#### KV Cache Updater

You can select the KV Cache update mechanism at runtime by setting the `KV_UPDATER` variable to either "shift_pointer" or "smart_mask". By default, it is set to "smart_mask".
`KV_UPDATER` = "shift_pointer"
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
```

#### Lookahead Decoding Mode
@@ -147,7 +153,7 @@ You can choose the lookahead mode to enhance decoding speed. To use this mode, y
For more details, please refer to the paper ["Break the Sequential Dependency of LLM Inference Using Lookahead Decoding"](https://arxiv.org/abs/2402.02057)

```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
```

#### Masked Softmax
22 changes: 12 additions & 10 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -264,8 +264,8 @@ def quantize(

self.llama_graph_module = convert_pt2e(fx_graph_module)

logging.info("Verifying the QDQ model...")
if args.eval_perplexity:
logging.info("Verifying the QDQ model...")
# Check qdq cpu results
graph_module_inference(
args=args,
@@ -362,6 +362,7 @@ def compile(args, pte_filename, tokenizer):
kv_config.use_kv_cache = True
kv_config.enable_masked_softmax = args.enable_masked_softmax
kv_config.enable_r3 = args.r3
kv_config.kv_io_bit_width = 16 if args.ptq == "16a8w" else 8

prefill_config = copy.copy(kv_config)
prefill_config.use_kv_cache = (
@@ -535,11 +536,15 @@ def permute(w, heads):
fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32}
if args.ptq:
use_fp16 = False
fixed_point_type["kv_type"] = torch.uint8
if args.ptq == "8a8w":
fixed_point_type["io_type"] = torch.uint8
elif args.ptq in ("16a4w", "16a4w_block", "16a8w"):
fixed_point_type["kv_type"] = torch.uint8
elif args.ptq in ("16a4w", "16a4w_block"):
fixed_point_type["io_type"] = torch.uint16
fixed_point_type["kv_type"] = torch.uint8
elif args.ptq == "16a8w":
fixed_point_type["io_type"] = torch.uint16
fixed_point_type["kv_type"] = torch.uint16
else:
assert args.ptq in [
"8a8w",
@@ -572,13 +577,10 @@ def permute(w, heads):

if args.ptq:
start_quantize_ts = time.time()
custom_annotations = (
# For qwen2.5, skip annotate_conv can improve result.
partial(
annotate_matmul_16a8w,
annotate_conv=args.ptq != "16a8w",
),
)
custom_annotations = ()
if args.ptq != "16a8w":
# 16a8w use 16bit kv io, so skip this custom annotation
custom_annotations = custom_annotations + (annotate_matmul_16a8w,)
if args.decoder_model in {"stories110m", "stories260k"}:
custom_annotations = custom_annotations + (
annotate_linear_16a8w_in_affine_layer,
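For quick reference, the `--ptq` branches above now resolve to the following activation (io) and KV cache dtypes. This is a hypothetical condensed restatement of the diff for readability; the helper `fixed_point_types` does not exist in `llama.py` itself.

```python
import torch

def fixed_point_types(ptq: str) -> dict:
    # Hypothetical restatement of the fixed_point_type selection above.
    if ptq == "8a8w":
        return {"io_type": torch.uint8, "kv_type": torch.uint8}
    if ptq in ("16a4w", "16a4w_block"):
        return {"io_type": torch.uint16, "kv_type": torch.uint8}
    if ptq == "16a8w":
        # New in this PR: 16-bit KV cache I/O, which is also why the
        # annotate_matmul_16a8w custom annotation is skipped for this mode.
        return {"io_type": torch.uint16, "kv_type": torch.uint16}
    raise AssertionError(f"Unsupported ptq mode: {ptq}")
```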
2 changes: 2 additions & 0 deletions examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -444,6 +444,7 @@ def __init__(
self.output_new_cache_only = output_new_cache_only
self.use_i64_token = use_i64_token
self.output_cache = output_cache
self.kv_io_bit_width = config.kv_io_bit_width

self.layers = nn.ModuleList(
[
@@ -607,4 +608,5 @@ def get_metadata(self):
"get_n_layers": self.n_layers,
"get_vocab_size": self.vocab_size,
"get_use_kv_cache": self.use_kv_cache,
"get_kv_io_bit_width": self.kv_io_bit_width,
}
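Read together with the `llama.py` hunk above (`kv_config.kv_io_bit_width = 16 if args.ptq == "16a8w" else 8`), the new metadata entry gives the on-device runner a way to discover the KV cache I/O width. A condensed sketch, with the model-config and export plumbing elided and assumed:

```python
# Condensed sketch (not literal export code) of how the KV I/O bit width
# travels from the PTQ mode to the exported program's metadata.
ptq = "16a8w"  # example value; normally taken from the --ptq argument

# llama.py: compile() derives the bit width from the PTQ mode.
kv_io_bit_width = 16 if ptq == "16a8w" else 8

# static_llama.py: the model stores it and returns it from get_metadata(),
# so the resulting .pte exposes a "get_kv_io_bit_width" method for the runner
# (see the qnn_llama_runner.cpp change below).
metadata = {
    "get_use_kv_cache": True,
    "get_kv_io_bit_width": kv_io_bit_width,
}
```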
58 changes: 44 additions & 14 deletions examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -133,24 +133,16 @@ std::string get_formatted_prompt(
return formatted_prompt;
}

int main(int argc, char** argv) {
std::vector<std::string> prompts = CollectPrompts(argc, argv);
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (!gflags::GetCommandLineFlagInfoOrDie("prompt").is_default &&
!gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default) {
ET_CHECK_MSG(false, "Only provide prompt or tokenized_input but not both.");
}
if (!gflags::GetCommandLineFlagInfoOrDie("dump_logits_path").is_default &&
FLAGS_eval_mode != 0) {
ET_CHECK_MSG(
false, "Only TokenGenerator(kv) mode is supported to dump all logits.");
}

template <typename T>
void start_runner(
std::unique_ptr<executorch::extension::Module> module,
std::vector<std::string>& prompts) {
bool use_tokenized_prompt =
gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default ? false
: true;
// create llama runner
example::Runner runner(
example::Runner<T> runner(
std::move(module),
FLAGS_decoder_model_version.c_str(),
FLAGS_model_path.c_str(),
FLAGS_tokenizer_path.c_str(),
@@ -196,5 +188,43 @@ int main(int argc, char** argv) {

fout.write(buf.data(), buf.size());
fout.close();
}

int main(int argc, char** argv) {
std::vector<std::string> prompts = CollectPrompts(argc, argv);
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (!gflags::GetCommandLineFlagInfoOrDie("prompt").is_default &&
!gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default) {
ET_CHECK_MSG(false, "Only provide prompt or tokenized_input but not both.");
}
if (!gflags::GetCommandLineFlagInfoOrDie("dump_logits_path").is_default &&
FLAGS_eval_mode != 0) {
ET_CHECK_MSG(
false, "Only TokenGenerator(kv) mode is supported to dump all logits.");
}

std::unique_ptr<executorch::extension::Module> module =
std::make_unique<executorch::extension::Module>(
FLAGS_model_path.c_str(),
executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
// Using 8bit as default since this meta is introduced with 16bit kv io
// support and older models only have 8bit kv io.
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
if (module->method_names()->count("get_kv_io_bit_width") > 0) {
kv_bitwidth = static_cast<example::KvBitWidth>(
module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
}

if (kv_bitwidth == example::KvBitWidth::kWidth8) {
start_runner<uint8_t>(std::move(module), prompts);
} else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
start_runner<uint16_t>(std::move(module), prompts);
} else {
ET_CHECK_MSG(
false,
"Unsupported kv bitwidth: %ld",
static_cast<int64_t>(kv_bitwidth));
}

return 0;
}