
Commit 65bfeca

Merge branch 'main' into use-quantize_
2 parents: fa4c120 + 057558f

30 files changed (+1365, -111 lines)

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ jobs:
         fi

         # This has already been cached in the docker image
-        lintrunner init 2> /dev/null
+        lintrunner init

         RC=0
         # Run lintrunner on all files

backends/qualcomm/_passes/i64_to_i32.py

Lines changed: 2 additions & 0 deletions
@@ -28,8 +28,10 @@ class I64toI32(ExportPass):
     I64_OPS = {
         exir_ops.edge.aten.argmin.default,
         exir_ops.edge.aten.arange.start_step,
+        exir_ops.edge.aten.cumsum.default,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.scalar_tensor.default,
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     }
     # This dict is to ensure that the input of the OPs are int64 due to Pytorch restrictions.
     # For example, scatter op can only accept args[2], the index, as int64.

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 2 additions & 1 deletion
@@ -86,7 +86,8 @@ def _build_tensor_constant(
         dtype=(
             node.args[0].meta["val"].dtype
             if not is_float_tensor(node)
-            and not SCALAR_OPS.get(node.target).use_self_dtype
+            and (info := SCALAR_OPS.get(node.target))
+            and not info.use_self_dtype
             else node.meta["val"].dtype
         ),
         device=node.meta["val"].device,
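This change guards the case where `SCALAR_OPS.get(node.target)` returns `None`: the old expression would raise `AttributeError` for targets without a registered entry. A minimal sketch of the same guard pattern, using stand-in names (`OpInfo`, `pick_dtype`) rather than the pass's real internals:

```python
from dataclasses import dataclass


@dataclass
class OpInfo:
    # Stand-in for the per-op metadata stored in SCALAR_OPS.
    use_self_dtype: bool = False


SCALAR_OPS = {"aten.add.Scalar": OpInfo(use_self_dtype=False)}


def pick_dtype(target: str, arg_dtype: str, self_dtype: str) -> str:
    # Old form: SCALAR_OPS.get(target).use_self_dtype raises AttributeError
    # when `target` has no entry. The walrus form falls back to self_dtype.
    if (info := SCALAR_OPS.get(target)) and not info.use_self_dtype:
        return arg_dtype
    return self_dtype


print(pick_dtype("aten.add.Scalar", "torch.int32", "torch.float32"))    # torch.int32
print(pick_dtype("aten.unknown.Scalar", "torch.int32", "torch.float32"))  # torch.float32, no crash
```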

backends/qualcomm/_passes/replace_inf_values.py

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,12 @@ def call(self, graph_module: torch.fx.GraphModule):
                     arg_list[index] = torch.finfo(torch.float32).min
                 elif arg == float("inf"):
                     arg_list[index] = torch.finfo(torch.float32).max
+
+            if node.target == torch.ops.aten.masked_fill.Scalar:
+                if arg_list[2] == torch.finfo(torch.float32).max:
+                    arg_list[2] = 255
+                elif arg_list[2] == torch.finfo(torch.float32).min:
+                    arg_list[2] = -255
             node.args = tuple(arg_list)

         graph_module.recompile()
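The new branch narrows `masked_fill`'s fill value from float32 min/max to +/-255, presumably so the extreme fill constant does not dominate later value ranges (for example under quantization). A standalone sketch of the value mapping this pass applies (plain functions; the real pass rewrites FX graph node arguments in place):

```python
import torch

# Standalone sketch of the value mapping applied by ReplaceInfValues; the
# real pass rewrites FX node arguments in place, this only shows the scalars.
F32_MAX = torch.finfo(torch.float32).max
F32_MIN = torch.finfo(torch.float32).min


def replace_inf(value: float, is_masked_fill_value: bool = False) -> float:
    if value == float("-inf"):
        value = F32_MIN
    elif value == float("inf"):
        value = F32_MAX
    # masked_fill's fill value is narrowed further to +/-255 (see diff above).
    if is_masked_fill_value:
        if value == F32_MAX:
            value = 255
        elif value == F32_MIN:
            value = -255
    return value


print(replace_inf(float("-inf"), is_masked_fill_value=True))  # -255
print(replace_inf(float("inf")))                              # 3.4028234663852886e+38
```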

backends/qualcomm/builders/op_cum_sum.py

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ def define_node(
         dim = self.get_param(node, input_tensor)

         output_tensor = self.get_tensor(node, node)
+        if output_tensor.dtype == torch.int64:
+            output_tensor = output_tensor.to(torch.int32)
         output_tensor_wrapper = self.define_tensor(
             node,
             node,
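The Qualcomm backend generally avoids int64 tensors (that is what the I64toI32 pass above handles), so the builder narrows a 64-bit cumsum output to 32-bit before registering it. A small, hedged torch-only illustration of the same narrowing, outside the builder API:

```python
import torch

# An int64 cumsum result, as the exported graph carries for an integer input;
# the QNN-side tensor is registered as int32 instead.
x = torch.randint(0, 10, size=(4,))
out = torch.cumsum(x, dim=0)
if out.dtype == torch.int64:
    out = out.to(torch.int32)
print(out, out.dtype)  # ... torch.int32
```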

backends/qualcomm/tests/models.py

Lines changed: 10 additions & 10 deletions
@@ -1101,6 +1101,16 @@ def forward(self, x):
         return torch.mean(x, (-1, -2))


+class MaskedFill(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, attn_mask):
+        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
+            attn_mask == 0, float(0.0)
+        )
+
+
 class Maximum(torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -1751,16 +1761,6 @@ def forward(self, x):
         )


-class MaskedFill(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, attn_mask):
-        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
-            attn_mask == 0, float(0.0)
-        )
-
-
 # Mimi Decoder has 0D tensor which QNN cannot handle.
 class ZeroDimTensor(torch.nn.Module):
     def __init__(self):
backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 64 additions & 3 deletions
@@ -272,9 +272,24 @@ def test_qnn_backend_cos(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_cumsum(self):
-        module = CumSum()  # noqa: F405
-        sample_input = (torch.randn(4),)
-        self.lower_module_and_test_output(module, sample_input)
+        sample_input = ()
+        test_comb = [
+            {
+                QCOM_MODULE: [CumSum()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (torch.randn(4),),
+                    (torch.randint(0, 10, size=(4,)),),
+                ],
+            }
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        self.lower_module_and_test_output(module, sample_input)
+                        index += 1

     def test_qnn_backend_einsum_outer_product(self):
         module = EinsumOuterProduct()  # noqa: F405
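The rewritten cumsum test loops over (module, input) combinations under `unittest`'s `subTest`, so each combination is reported separately instead of aborting the whole method at the first failure. A standalone illustration of that pattern:

```python
import unittest


class ParamExample(unittest.TestCase):
    def test_square_is_non_negative(self):
        # Each value gets its own pass/fail entry rather than stopping the
        # test method at the first failing combination.
        for index, value in enumerate([-2, 0, 3]):
            with self.subTest(i=index):
                self.assertGreaterEqual(value * value, 0)


if __name__ == "__main__":
    unittest.main()
```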
@@ -311,6 +326,12 @@ def test_qnn_backend_element_wise_add(self):
                 QCOM_MODULE: [AddConstantFloat()],  # noqa: F405
                 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
+            {
+                QCOM_MODULE: [
+                    AddConstantLong(),  # noqa: F405
+                ],
+                QCOM_SAMPLE_INPUTS: [(torch.randint(0, 10, size=(2, 3)),)],
+            },
         ]

         index = 0
@@ -4526,6 +4547,40 @@ def test_retinanet(self):
         else:
             self.assertGreaterEqual(msg["mAP"], 0.6)

+    def test_roberta(self):
+        if not self.required_envs([self.sentence_dataset]):
+            self.skipTest("missing required envs")
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/roberta.py",
+            "--dataset",
+            self.sentence_dataset,
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--device",
+            self.device,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+        ]
+        if self.host:
+            cmds.extend(["--host", self.host])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                self.assertGreaterEqual(msg["accuracy"], 0.5)
+
     def test_squeezenet(self):
         if not self.required_envs([self.image_dataset]):
             self.skipTest("missing required envs")
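Like the other on-device model tests, `test_roberta` launches the example script as a subprocess and blocks on a `multiprocessing.connection.Listener` until the script reports a JSON result. A minimal sketch of both ends of that handshake; it assumes the script side uses `multiprocessing.connection.Client`, and the address and payload here are illustrative:

```python
import json
import threading
from multiprocessing.connection import Client, Listener

ADDRESS = ("127.0.0.1", 6000)  # illustrative ip/port


def report_side():
    # What an example script is assumed to do once evaluation finishes.
    with Client(ADDRESS) as conn:
        conn.send(json.dumps({"accuracy": 0.87}))


if __name__ == "__main__":
    # Mirrors the test: bind first, then accept one connection and parse it.
    with Listener(ADDRESS) as listener:
        threading.Thread(target=report_side).start()
        conn = listener.accept()
        msg = json.loads(conn.recv())
        print("Error" not in msg and msg["accuracy"] >= 0.5)  # True
```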
@@ -5344,6 +5399,11 @@ def setup_environment():
         help="Location for imagenet dataset",
         type=str,
     )
+    parser.add_argument(
+        "--sentence_dataset",
+        help="Location for sentence dataset",
+        type=str,
+    )
     parser.add_argument(
         "-p",
         "--pretrained_weight",
@@ -5402,6 +5462,7 @@ def setup_environment():
     TestQNN.executorch_root = args.executorch_root
     TestQNN.artifact_dir = args.artifact_dir
     TestQNN.image_dataset = args.image_dataset
+    TestQNN.sentence_dataset = args.sentence_dataset
     TestQNN.pretrained_weight = args.pretrained_weight
     TestQNN.model_name = args.model_name
     TestQNN.online_prepare = args.online_prepare

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 0 deletions
@@ -183,6 +183,7 @@ class TestQNN(unittest.TestCase):
     executorch_root: str = ""
     artifact_dir: str = ""
     image_dataset: str = ""
+    sentence_dataset: str = ""
     pretrained_weight: str = ""
     enable_profile: bool = False
     op_package_dir: str = ""

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ list(
   ${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.h
   ${CMAKE_CURRENT_LIST_DIR}/runner/imem_alloc.h
   ${CMAKE_CURRENT_LIST_DIR}/runner/client_mem.h
+  ${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.h
   ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.cpp
   ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
   ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 15 additions & 3 deletions
@@ -4,13 +4,13 @@
 This file provides you the instructions to run LLAMA model with different parameters via Qualcomm HTP backend. We currently support the following models:
 1. LLAMA2 Stories 110M
 2. LLAMA3.2 1B
-3. LLAMA3.2 3B (WIP)
+3. LLAMA3.2 3B

 We offer the following modes to execute the model:

-KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.
+- KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.

-Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache modes to optimize token generation speed. Initially, it uses AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
+- Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache modes to optimize token generation speed. Initially, it uses AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
 - AR-N model: The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use it to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor in hybrid mode.
 - Prompt processing with AR-N model:
 <figure>
@@ -19,6 +19,7 @@ Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache
 </figcaption>
 </figure>

+- Lookahead Mode: Lookahead Mode introduces [lookahead decoding](https://arxiv.org/abs/2402.02057) and uses the AR-N model to process the prompt, enhancing token generation speed. While decoding multiple tokens in a single step is infeasible, an LLM can generate multiple guess tokens in parallel. These guess tokens may fit into future parts of the generated sequence. The lookahead decoder generates and verifies these guess tokens, integrating them into the sequence when they fit, so in some cases it obtains more than one token in a single step. The result is lossless.

 ## Instructions
 ### Note
@@ -127,3 +128,14 @@ You can select the KV Cache update mechanism at runtime by setting the `KV_UPDAT
 ```bash
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
 ```
+
+You can choose lookahead mode to enhance decoding speed. To use it, specify the following parameters:
+- `--ngram` (n-gram size): the size of the n-grams used in the lookahead process.
+- `--window` (window size): how many future tokens the algorithm attempts to predict in each step.
+- `--gcap` (verification candidates): the maximum number of speculations, or candidate n-grams, that the algorithm considers in each step for verification. It balances the trade-off between computational efficiency and exploring more possibilities.
+
+For more details, please refer to the paper ["Break the Sequential Dependency of LLM Inference Using Lookahead Decoding"](https://arxiv.org/abs/2402.02057).
+
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
+```
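To make the three knobs concrete, below is a toy, hedged sketch of the bookkeeping they control: an n-gram pool of guess continuations and a bounded number of candidates verified against what greedy decoding would have produced anyway, which is what keeps the result lossless. It is a simplified illustration, not the runner's actual `lhd_token_generator` implementation; in particular it mines n-grams from already-generated history instead of maintaining the parallel lookahead window that `--window` sizes.

```python
# Toy, lossless-style sketch of how --ngram (N) and --gcap (G) interact; the
# --window size would govern the parallel lookahead branch that fills the
# n-gram pool in the real runner, which this toy replaces with history mining.
N, G = 3, 2


def next_token(prefix):
    # Deterministic stand-in "model": derive the next token from the last one.
    return (prefix[-1] * 7 + 3) % 8


def ngram_pool(seq):
    # Index N-grams by their first token; candidates are the (N-1)-long tails.
    pool = {}
    for i in range(len(seq) - N + 1):
        head, *tail = seq[i : i + N]
        pool.setdefault(head, []).append(tail)
    return pool


seq = [1, 2]
for _ in range(6):  # a few decoding steps
    token = next_token(seq)
    seq.append(token)
    # Verification branch: check at most G candidate tails that follow `token`.
    for tail in ngram_pool(seq).get(token, [])[:G]:
        accepted = []
        for guess in tail:
            # Accept a guess only if greedy decoding would emit it anyway,
            # which is what makes lookahead decoding lossless.
            if next_token(seq + accepted) == guess:
                accepted.append(guess)
            else:
                break
        if accepted:
            seq.extend(accepted)  # more than one token obtained in this step
            break
print(seq)
```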
