
Commit af420f5

fix llama buck build (#13169)
Summary: Some recent changes break the llama buck build

Differential Revision: D79753385
1 parent 18098a4 commit af420f5

3 files changed: 30 additions, 3 deletions

examples/models/llama/evaluate/eager_eval.py

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@
 import torch
 
 from lm_eval.models.huggingface import HFLM as eval_wrapper
+from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
 from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken
 
@@ -24,7 +25,7 @@ class EagerEvalWrapper(eval_wrapper):
     def __init__(
         self,
         model: nn.Module,
-        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken, HuggingFaceTokenizer],
         max_seq_length: Optional[int] = None,
         use_kv_cache: bool = False,
     ):
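
A minimal usage sketch of the widened constructor contract above (not part of this commit): the import path for EagerEvalWrapper is inferred from the file path, build_eval_wrapper is an illustrative helper, and the model, tokenizer, and max_seq_length values are placeholders built elsewhere.

from typing import Union

import torch.nn as nn

from executorch.examples.models.llama.evaluate.eager_eval import EagerEvalWrapper
from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken


def build_eval_wrapper(
    model: nn.Module,
    tokenizer: Union[SentencePieceTokenizer, Tiktoken, HuggingFaceTokenizer],
) -> EagerEvalWrapper:
    # After this change the wrapper also accepts a HuggingFaceTokenizer,
    # alongside the SentencePiece and Tiktoken wrappers it already took.
    return EagerEvalWrapper(
        model=model,
        tokenizer=tokenizer,
        max_seq_length=512,  # placeholder value; the parameter stays Optional[int]
        use_kv_cache=False,
    )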

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 20 additions & 0 deletions
@@ -15,10 +15,30 @@ python_library(
     ],
 )
 
+python_library(
+    name = "decoder_utils",
+    srcs = [
+        "decoder_utils.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:eval_library",
+    ],
+)
+
+python_library(
+    name = "decoder_constants",
+    srcs = [
+        "decoder_constants.py",
+    ],
+)
+
 python_library(
     name = "llama_lib",
     srcs = ["llama.py"],
     deps = [
+        ":decoder_constants",
+        ":decoder_utils",
         "//executorch/examples/models/llama:source_transformation",
         "//caffe2:torch",
         "//executorch/backends/qualcomm/partition:partition",

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 8 additions & 2 deletions
@@ -44,14 +44,15 @@ def __init__(
         tokenizer: Union[
             SentencePieceTokenizer, TiktokenTokenizer, HuggingFaceTokenizer
         ],
-        max_seq_length: Optional[int],
+        max_seq_length: int,
         ar_len: int,
         use_kv_cache: bool,
         get_example_inputs: Callable,
         kv_updater: Callable,
         use_i64_token: bool,
     ):
         # n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
+        assert max_seq_length is not None, "max_seq_length must be provided"
         super().__init__(
             model=model, tokenizer=tokenizer, max_seq_length=max_seq_length - 1
         )
@@ -119,8 +120,10 @@ def __init__(
         for method in program.execution_plan:
             # Don't use tokenizer.n_words, the numbers are off once calling get_tokenizer()
             if method.name == "get_vocab_size":
+                # pyre-ignore
                 self.output_vocab_size = method.values[0].val.int_val
             if method.name == "get_max_seq_len":
+                # pyre-ignore
                 pte_max_seq_len = method.values[0].val.int_val
         assert self.output_vocab_size is not None, "Couldn't find the vocab size"
         assert pte_max_seq_len is not None, "Couldn't find the max_seq_len from pte"
@@ -156,6 +159,7 @@ def __init__(
         )
         self.adb.push(inputs=[], input_list="", files=[self.runtime_tokenizer_path])
         # n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
+        # pyre-ignore
         super().__init__(None, tokenizer, max_seq_length - 1)
 
     def _model_call(self, inps):
@@ -278,6 +282,7 @@ def kv_inference(
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
+        # pyre-ignore
         token_list = prompt.flatten().tolist()
     pos = len(token_list) if len(token_list) < ar_len else ar_len
     dtype = torch.int64 if use_i64_token else torch.int32
@@ -359,6 +364,7 @@ def prefill_inference(
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
+        # pyre-ignore
         token_list = prompt.flatten().tolist()
 
     pos = len(token_list)
@@ -405,7 +411,7 @@ def graph_module_inference(
     max_seq_len=512,
     kv_updater=smart_mask_updater,
     use_i64_token=False,
-    event_name: str = None,
+    event_name: Optional[str] = None,
 ):
     if args.tasks is None:
         if use_kv_cache:
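
The last hunk replaces an implicit-Optional default annotation; a standalone sketch of the pattern follows (log_event is an illustrative name, not from this commit):

from typing import Optional


def log_event(event_name: Optional[str] = None) -> None:
    # `event_name: str = None` is rejected by strict checkers such as pyre,
    # since None is not a str; the explicit Optional form matches the fix above.
    if event_name is not None:
        print(f"event: {event_name}")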
