
Commit db2c555

Enable x86 runner for static llama
1 parent 526a0d8 commit db2c555

6 files changed: +143 additions, −54 deletions

.github/workflows/pull.yml

Lines changed: 37 additions & 3 deletions
@@ -440,7 +440,8 @@ jobs:
         # Test llama2
         PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh -model stories110M -build_tool "${BUILD_TOOL}" -mode "${MODE}" -dtype "${DTYPE}" -pt2e_quantize "${PT2E_QUANTIZE}"

-  test-static-llama-runner-qnn-linux:
+  # Compile only as weight sharing is not applicable on x86
+  test-static-llama-size-qnn-linux:
     name: test-static-llama-runner-qnn-linux
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
     strategy:
@@ -459,13 +460,46 @@ jobs:
         PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh
         PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh

+        # Setup executorch
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "${BUILD_TOOL}"
+
         # Retrieve 110M Stories Llama Artifacts
-        PYTHON_EXECUTABLE=python bash .ci/scripts/utils.sh
         PYTHON_EXECUTABLE=python download_stories_model_artifacts
+        $PYTHON_EXECUTABLE -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin

-        # Test static llama stories110m
+        # Test static llama stories110m pte size
         PYTHON_EXECUTABLE=python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleScript.test_stories_single_llama --model SM8650 --build_folder build-android/ --executorch_root . --artifact_dir . --compile_only"

+  # Checks accuracy with weight sharing disabled since x86 does not support weight sharing.
+  test-static-llama-accuracy-qnn-linux:
+    name: test-static-llama-runner-qnn-linux
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    strategy:
+      fail-fast: false
+    with:
+      runner: linux.2xlarge
+      docker-image: executorch-ubuntu-22.04-qnn-sdk
+      submodules: 'true'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 900
+      script: |
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh
+        PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
+
+        # Setup executorch
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "${BUILD_TOOL}"
+
+        # Retrieve 110M Stories Llama Artifacts
+        PYTHON_EXECUTABLE=python download_stories_model_artifacts
+        $PYTHON_EXECUTABLE -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
+
+        # Test static llama stories110m accuracy
+        PYTHON_EXECUTABLE=python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleScript.test_stories_single_llama --model SM8650 --build_folder build-x86_64/ --executorch_root . --artifact_dir . --enable_x86_64"
+
   test-qnn-models-linux:
     name: test-qnn-models-linux
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
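For local reproduction, the new accuracy job boils down to a single test invocation. A minimal sketch, assuming a configured Python environment with the QNN SDK and an x86_64 build in build-x86_64/ (the subprocess wrapper here is illustrative, not part of the commit):

```python
# Hypothetical local replay of the CI accuracy step; flags mirror the workflow above.
import subprocess

subprocess.run(
    [
        "python", "backends/qualcomm/tests/test_qnn_delegate.py",
        "-k", "TestExampleScript.test_stories_single_llama",
        "--model", "SM8650",
        "--build_folder", "build-x86_64/",
        "--executorch_root", ".",
        "--artifact_dir", ".",
        "--enable_x86_64",  # run on the host instead of an attached Android device
    ],
    check=True,  # raise if the test fails
)
```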

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 8 additions & 8 deletions
@@ -1930,6 +1930,7 @@ def test_qnn_backend_multi_graphs(self):
                 soc_model=self.chipset_table[TestQNN.model],
                 backend_options=backend_options,
                 multiple_graphs=True,
+                weight_sharing=True,
                 graph_name=graph_name,
             )
             for graph_name in graph_names
@@ -2418,6 +2419,7 @@ def test_qnn_backend_multi_graphs(self):
                 soc_model=self.chipset_table[TestQNN.model],
                 backend_options=backend_options,
                 multiple_graphs=True,
+                weight_sharing=True,
                 graph_name=graph_name,
             )
             for graph_name in graph_names
@@ -3621,6 +3623,8 @@ def test_stories_single_llama(self):
            cmds.extend(["--device", self.device])
        if self.host:
            cmds.extend(["--host", self.host])
+        if self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])

        golden_start_with = "Once upon a time,"
        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
@@ -3634,8 +3638,10 @@ def test_stories_single_llama(self):
        if not self.compile_only:
            model_out = msg["result"][0]
            self.assertTrue(model_out.startswith(golden_start_with))
-            pte_size = msg["pte_size"]
-            self.assertLessEqual(pte_size, 130000000)
+            # x86 does not allow weight sharing, so we don't check pte size
+            if not self.enable_x86_64:
+                pte_size = msg["pte_size"]
+                self.assertLessEqual(pte_size, 130000000)

    @unittest.skip("dynamic shape inputs appear in recent torch.export.export")
    def test_mobilebert(self):
@@ -3840,12 +3846,6 @@ def setup_environment():
        help="Path to open source software model repository",
        type=str,
    )
-    parser.add_argument(
-        "-x",
-        "--enable_x86_64",
-        help="Enable unittest to be executed on x86_64 platform",
-        action="store_true",
-    )

    args, ns_args = parser.parse_known_args(namespace=unittest)
    TestQNN.host = args.host
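The enable_x86_64 attribute now comes from the shared parser in examples/qualcomm/utils.py (the last file in this diff) rather than a test-local argument. A sketch of the resulting flow, with hypothetical stand-in values:

```python
# Hypothetical flow: the shared flag gates both command construction and the size check.
enable_x86_64 = True  # parsed by setup_common_args_and_variables() via -x/--enable_x86_64

cmds = ["python", "examples/qualcomm/oss_scripts/llama/llama.py"]  # assumed base command
if enable_x86_64:
    cmds.extend(["--enable_x86_64"])  # inference runs on the host instead of over ADB

msg = {"pte_size": 125_000_000}  # stand-in for the result reported back over the socket
if not enable_x86_64:
    # Weight sharing shrinks the multi-graph .pte; without it the 130 MB bound would not hold.
    assert msg["pte_size"] <= 130000000
```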

backends/qualcomm/utils/utils.py

Lines changed: 12 additions & 1 deletion
@@ -1047,6 +1047,7 @@ def generate_qnn_executorch_compiler_spec(
     shared_buffer: bool = False,
     is_from_context_binary: bool = False,
     multiple_graphs: bool = False,
+    weight_sharing: bool = False,
     graph_name: str = "forward",
 ) -> List[CompileSpec]:
     """
@@ -1077,6 +1078,7 @@ def generate_qnn_executorch_compiler_spec(
         is_from_context_binary: True if current graph comes from pre-built context binary.
         multiple_graphs: True if multiple methods are expected to have in single .pte file.
             Please see test cases for post-processing example.
+        weight_sharing: Used with multiple_graphs, where model size will be reduced when operations have the same weights across multiple graphs.
         graph_name: Assign unique graph name if 'multiple_graphs' is used.

     Returns:
@@ -1097,6 +1099,12 @@ def generate_qnn_executorch_compiler_spec(
             stacklevel=1,
         )

+    if weight_sharing and not multiple_graphs:
+        warnings.warn(
+            "Weight sharing is intended for multiple graphs scenario, please ensure if there are multiple graphs",
+            stacklevel=1,
+        )
+
     qnn_executorch_options = QnnExecuTorchOptions(
         _soc_info_table[soc_model], backend_options
     )
@@ -1138,7 +1146,10 @@ def generate_qnn_executorch_compiler_spec(

     if multiple_graphs:
         # enable weight sharing mechanism if multiple graphs appear
-        if backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend:
+        if (
+            backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend
+            and weight_sharing
+        ):
             backend_options.htp_options.use_weight_sharing = True

     return [
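With this change, callers opt into weight sharing explicitly instead of getting it implicitly from multiple_graphs. A minimal sketch of the new knob, mirroring the test's list comprehension; the generate_htp_compiler_spec and QcomChipset import paths are assumptions, and the graph names are placeholders:

```python
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
    QcomChipset,
)
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)

backend_options = generate_htp_compiler_spec(use_fp16=False)
compiler_specs = [
    generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.SM8650,
        backend_options=backend_options,
        multiple_graphs=True,
        weight_sharing=True,  # only honored on the HTP backend; warns without multiple_graphs
        graph_name=graph_name,
    )
    for graph_name in ("prefill_forward", "kv_forward")  # placeholder graph names
]
```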

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 77 additions & 38 deletions
@@ -12,6 +12,7 @@
 import json
 import logging
 import os
+import subprocess
 import sys
 import time
 from functools import partial
@@ -594,6 +595,9 @@ def compile(args, pte_filename, tokenizer):
             backend_options=backend_options,
             shared_buffer=args.shared_buffer,
             multiple_graphs=True,
+            weight_sharing=(
+                False if args.enable_x86_64 else True
+            ),  # x86 emulator does not support weight sharing
             graph_name=graph_name,
         )
         for graph_name in graph_names
@@ -751,48 +755,11 @@ def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_p
     else:
         raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

-    seq_len = args.prefill_seq_len if args.model_mode == "prefill" else args.kv_seq_len
-    runner_args = " ".join(
-        [
-            f"--model_path {pte_filename}.pte",
-            "--output_path outputs/outputs.txt",
-            f"--tokenizer_path {os.path.basename(runtime_tokenizer_path)}",
-            f'--prompt "{args.prompt}"',
-            f"--seq_len {seq_len}",
-            f"--eval_mode {eval_mode}",
-            f"--temperature {args.temperature}",
-            f"--system_prompt '{args.system_prompt}'",
-            f"--logits_scale {quant_attrs['scale']}",
-            f"--logits_offset {quant_attrs['zero_point']}",
-            f"--kv_updator {'SmartMask' if args.kv_updator == smart_mask_updator else 'ShiftPointer'}",
-        ]
-    )
-    runner_cmd = " ".join(
-        [
-            f"cd {workspace} &&",
-            f"./qnn_llama_runner {runner_args}",
-        ]
-    )
-
     pte_path = (
         f"{pre_gen_pte}/{pte_filename}.pte"
         if pre_gen_pte
         else f"{args.artifact}/{pte_filename}.pte"
     )
-    adb = SimpleADB(
-        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
-        build_path=f"{args.build_folder}",
-        pte_path=pte_path,
-        workspace=workspace,
-        device_id=args.device,
-        host_id=args.host,
-        soc_model=args.model,
-        shared_buffer=args.shared_buffer,
-        runner=f"examples/qualcomm/oss_scripts/llama/qnn_llama_runner",
-    )
-    # No pregen inputs, input_list is not required
-    adb.push(inputs=[], input_list="", files=[runtime_tokenizer_path])
-    adb.execute(custom_runner_cmd=runner_cmd)

     # collect output data
     output_data_folder = f"{args.artifact}/outputs"
@@ -803,7 +770,79 @@ def post_process():
         with open(f"{args.artifact}/outputs/outputs.txt", "r") as f:
             outputs.append(f.read())

-    adb.pull(output_path=args.artifact, callback=post_process)
+    seq_len = args.prefill_seq_len if args.model_mode == "prefill" else args.kv_seq_len
+    runner_args = " ".join(
+        [
+            f'--prompt "{args.prompt}"',
+            f"--eval_mode {eval_mode}",
+            f"--temperature {args.temperature}",
+            f"--system_prompt '{args.system_prompt}'",
+            f"--logits_scale {quant_attrs['scale']}",
+            f"--logits_offset {quant_attrs['zero_point']}",
+        ]
+    )
+
+    runner_cmd = ""
+    if args.enable_x86_64:
+        # x86 emulator is intended for CI and not performance. Check only the first few tokens.
+        seq_len = min(seq_len, 16)
+
+        if args.kv_updator == smart_mask_updator:
+            logging.warning(
+                "x86 only support ShiftPointer, overwrite kv_updator to ShiftPointer"
+            )
+
+        qnn_sdk = os.getenv("QNN_SDK_ROOT")
+        target = "x86_64-linux-clang"
+        runner_cmd = " ".join(
+            [
+                f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&",
+                f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/qnn_llama_runner",
+                f"--tokenizer_path {runtime_tokenizer_path}",
+                f"--model_path {pte_path}",
+                f"--seq_len {seq_len}",
+                f"--output_path {args.artifact}/outputs/outputs.txt",
+                f"--kv_updator ShiftPointer",
+                runner_args,
+            ]
+        )
+        subprocess.run(
+            runner_cmd,
+            shell=True,
+            executable="/bin/bash",
+            capture_output=True,
+        )
+        post_process()
+    else:
+        runner_cmd = " ".join(
+            [
+                f"cd {workspace} &&",
+                f"./qnn_llama_runner",
+                f"--tokenizer_path {os.path.basename(runtime_tokenizer_path)}",
+                f"--model_path {pte_filename}.pte",
+                f"--seq_len {seq_len}",
+                "--output_path outputs/outputs.txt",
+                f"--kv_updator {'SmartMask' if args.kv_updator == smart_mask_updator else 'ShiftPointer'}",
+                runner_args,
+            ]
+        )
+
+        adb = SimpleADB(
+            qnn_sdk=os.getenv("QNN_SDK_ROOT"),
+            build_path=f"{args.build_folder}",
+            pte_path=pte_path,
+            workspace=workspace,
+            device_id=args.device,
+            host_id=args.host,
+            soc_model=args.model,
+            shared_buffer=args.shared_buffer,
+            runner=f"examples/qualcomm/oss_scripts/llama/qnn_llama_runner",
+        )
+        # No pregen inputs, input_list is not required
+        adb.push(inputs=[], input_list="", files=[runtime_tokenizer_path])
+        adb.execute(custom_runner_cmd=runner_cmd)
+
+        adb.pull(output_path=args.artifact, callback=post_process)
     if args.ip and args.port != -1:
         pte_size = os.path.getsize(pte_path)
         with Client((args.ip, args.port)) as conn:
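The x86 branch swaps the ADB round trip for a plain shell invocation: it prepends an LD_LIBRARY_PATH export so the runner can find the QNN x86_64 host libraries, then runs the same qnn_llama_runner binary locally. A standalone sketch of that pattern, where all paths, artifact names, and the fallback SDK location are illustrative:

```python
import os
import subprocess

qnn_sdk = os.getenv("QNN_SDK_ROOT", "/opt/qnn")  # assumed fallback path
target = "x86_64-linux-clang"
build_folder = "build-x86_64"

runner_cmd = " ".join(
    [
        # Expose the QNN host libs and the built extension libs, then run the binary.
        f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{build_folder}/lib &&",
        f"./{build_folder}/examples/qualcomm/oss_scripts/llama/qnn_llama_runner",
        "--tokenizer_path tokenizer.bin",  # illustrative artifact names
        "--model_path llama_qnn.pte",
        "--output_path outputs.txt",
        "--seq_len 16",  # CI checks only the first few tokens
        "--kv_updator ShiftPointer",  # x86 supports only ShiftPointer
        '--prompt "Once upon a time"',
    ]
)
result = subprocess.run(
    runner_cmd, shell=True, executable="/bin/bash", capture_output=True, text=True
)
print(result.stdout)  # generated tokens also land in --output_path
```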

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 2 additions & 4 deletions
@@ -12,10 +12,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from executorch.examples.models.llama.llama_transformer import (
-    ModelArgs,
-    precompute_freqs_cis,
-)
+from executorch.examples.models.llama.llama_transformer import ModelArgs
+from executorch.examples.models.llama.rope import precompute_freqs_cis


 def apply_rotary_emb_single(

examples/qualcomm/utils.py

Lines changed: 7 additions & 0 deletions
@@ -524,6 +524,13 @@ def setup_common_args_and_variables():
        default=False,
    )

+    parser.add_argument(
+        "-x",
+        "--enable_x86_64",
+        help="Enable unittest to be executed on x86_64 platform",
+        action="store_true",
+    )
+
    # QNN_SDK_ROOT might also be an argument, but it is used in various places.
    # So maybe it's fine to just use the environment.
    if "QNN_SDK_ROOT" not in os.environ:
