Merged
examples/models/llama/evaluate/eager_eval.py (3 changes: 2 additions & 1 deletion)
@@ -10,6 +10,7 @@
import torch

from lm_eval.models.huggingface import HFLM as eval_wrapper
+ from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken

@@ -24,7 +25,7 @@ class EagerEvalWrapper(eval_wrapper):
def __init__(
self,
model: nn.Module,
- tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+ tokenizer: Union[SentencePieceTokenizer, Tiktoken, HuggingFaceTokenizer],
max_seq_length: Optional[int] = None,
use_kv_cache: bool = False,
):
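The widened Union means the eager evaluator now accepts a HuggingFace tokenizer alongside the SentencePiece and Tiktoken ones. A minimal sketch of how the updated wrapper might be constructed, assuming HuggingFaceTokenizer loads from a tokenizer.json path and that the import path mirrors the file location (both assumptions, not taken from this diff):

import torch.nn as nn
from executorch.examples.models.llama.evaluate.eager_eval import EagerEvalWrapper
from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer

def build_eval_wrapper(model: nn.Module, tokenizer_json_path: str) -> EagerEvalWrapper:
    # Assumption: HuggingFaceTokenizer is constructed from a HuggingFace tokenizer.json file.
    tokenizer = HuggingFaceTokenizer(tokenizer_json_path)
    # Arguments follow the __init__ signature shown in the hunk above.
    return EagerEvalWrapper(
        model=model,
        tokenizer=tokenizer,
        max_seq_length=2048,  # any positive context length; None keeps the wrapper default
        use_kv_cache=False,
    )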
examples/qualcomm/oss_scripts/llama/TARGETS (20 changes: 20 additions & 0 deletions)
@@ -15,10 +15,30 @@ python_library(
],
)

+ python_library(
+ name = "decoder_utils",
+ srcs = [
+ "decoder_utils.py",
+ ],
+ deps = [
+ "//caffe2:torch",
+ "//executorch/examples/models/llama:eval_library",
+ ],
+ )
+
+ python_library(
+ name = "decoder_constants",
+ srcs = [
+ "decoder_constants.py",
+ ],
+ )
+
python_library(
name = "llama_lib",
srcs = ["llama.py"],
deps = [
+ ":decoder_constants",
+ ":decoder_utils",
"//executorch/examples/models/llama:source_transformation",
"//caffe2:torch",
"//executorch/backends/qualcomm/partition:partition",
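Exposing decoder_utils and decoder_constants as standalone python_library targets lets any target in this package depend on them the way llama_lib now does. A hypothetical consumer stanza, for illustration only (the target name and source file are invented, not part of this change):

python_library(
    name = "my_decoder_tool",
    srcs = ["my_decoder_tool.py"],
    deps = [
        ":decoder_constants",
        ":decoder_utils",
    ],
)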
examples/qualcomm/oss_scripts/llama/decoder_utils.py (10 changes: 8 additions & 2 deletions)
@@ -44,14 +44,15 @@ def __init__(
tokenizer: Union[
SentencePieceTokenizer, TiktokenTokenizer, HuggingFaceTokenizer
],
- max_seq_length: Optional[int],
+ max_seq_length: int,
ar_len: int,
use_kv_cache: bool,
get_example_inputs: Callable,
kv_updater: Callable,
use_i64_token: bool,
):
# n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
+ assert max_seq_length is not None, "max_seq_length must be provided"
super().__init__(
model=model, tokenizer=tokenizer, max_seq_length=max_seq_length - 1
)
@@ -119,8 +120,10 @@ def __init__(
for method in program.execution_plan:
# Don't use tokenizer.n_words, the numbers are off once calling get_tokenizer()
if method.name == "get_vocab_size":
+ # pyre-ignore
self.output_vocab_size = method.values[0].val.int_val
if method.name == "get_max_seq_len":
+ # pyre-ignore
pte_max_seq_len = method.values[0].val.int_val
assert self.output_vocab_size is not None, "Couldn't find the vocab size"
assert pte_max_seq_len is not None, "Couldn't find the max_seq_len from pte"
@@ -156,6 +159,7 @@ def __init__(
)
self.adb.push(inputs=[], input_list="", files=[self.runtime_tokenizer_path])
# n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
+ # pyre-ignore
super().__init__(None, tokenizer, max_seq_length - 1)

def _model_call(self, inps):
@@ -278,6 +282,7 @@ def kv_inference(
else:
raise RuntimeError("Unknown tokenizer")
else:
+ # pyre-ignore
token_list = prompt.flatten().tolist()
pos = len(token_list) if len(token_list) < ar_len else ar_len
dtype = torch.int64 if use_i64_token else torch.int32
@@ -359,6 +364,7 @@ def prefill_inference(
else:
raise RuntimeError("Unknown tokenizer")
else:
+ # pyre-ignore
token_list = prompt.flatten().tolist()

pos = len(token_list)
@@ -405,7 +411,7 @@ def graph_module_inference(
max_seq_len=512,
kv_updater=smart_mask_updater,
use_i64_token=False,
- event_name: str = None,
+ event_name: Optional[str] = None,
):
if args.tasks is None:
if use_kv_cache:
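The final hunk applies the usual typing rule that a parameter defaulting to None must be declared Optional, which is what pyre would flag for event_name: str = None. A standalone illustration of the pattern (generic names, not from this file):

from typing import Optional

def log_event(event_name: Optional[str] = None) -> None:
    # "event_name: str = None" is rejected by strict type checkers such as pyre,
    # because None is not a valid str; Optional[str] makes the None default legal.
    if event_name is not None:
        print(f"event: {event_name}")

log_event()            # allowed: defaults to None
log_event("prefill")   # allowed: a concrete str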