# LICENSE file in the root directory of this source tree.

import argparse
-import logging
import copy
import json
-import torch
-from lm_eval.evaluator import simple_evaluate

from typing import List, Optional, Tuple


from executorch.examples.models.llama.eval_llama_lib import (
    build_args_parser,
-    GraphModuleEvalWrapper
+    GraphModuleEvalWrapper,
)

-from pytorch_tokenizers import get_tokenizer
-
from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
-    LlamaModel,
-    ModelArgs,
+    LlamaModel,
+    ModelArgs,
)
+from lm_eval.evaluator import simple_evaluate
+
+from pytorch_tokenizers import get_tokenizer

class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device='cuda'):
+    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
        super(WrappedLlamaModel, self).__init__()
        self.model = model
        self.max_seq_len = max_seq_len
        self.use_kv_cache = use_kv_cache
        self.device = device

-    def forward(self,
+    def forward(
+        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        *args,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        # Pad input if necessary, since LlamaModel requires static shape
        if tokens.shape[1] != self.max_seq_len:
-            tokens = torch.nn.functional.pad(tokens, (self.max_seq_len - tokens.shape[1],0))
-        atten_mask = self.model.get_example_inputs(self.use_kv_cache)[1].to(device=self.device).to(dtype=torch.bfloat16)
+            tokens = torch.nn.functional.pad(
+                tokens, (self.max_seq_len - tokens.shape[1], 0)
+            )
+        atten_mask = (
+            self.model.get_example_inputs(self.use_kv_cache)[1]
+            .to(device=self.device)
+            .to(dtype=torch.bfloat16)
+        )
        return self.model.forward(tokens, atten_mask, input_pos, *args)


-
def gen_eval_wrapper(model_name, args):
    tokenizer = get_tokenizer(args.tokenizer_path)
    with open(args.params) as f:
@@ -66,7 +70,13 @@ def gen_eval_wrapper(model_name, args):
    )
    config = prefill_config
    use_i64_token = args.embedding_quantize is not None
-    model = LlamaModel(config, ar_len=args.prefill_ar_len, output_new_cache_only=True, output_cache=False, use_i64_token=use_i64_token)
+    model = LlamaModel(
+        config,
+        ar_len=args.prefill_ar_len,
+        output_new_cache_only=True,
+        output_cache=False,
+        use_i64_token=use_i64_token,
+    )
    state_dict = torch.load(
        args.checkpoint, weights_only=True, map_location=args.device, mmap=True
    )
@@ -111,7 +121,9 @@ def permute(w, heads):
    model.to(dtype=torch.bfloat16)
    model.to(args.device)

-    wrapped_model = WrappedLlamaModel(model, args.use_kv_cache, args.max_seq_length, args.device)
+    wrapped_model = WrappedLlamaModel(
+        model, args.use_kv_cache, args.max_seq_length, args.device
+    )

    return GraphModuleEvalWrapper(
        model=wrapped_model,
@@ -123,7 +135,6 @@ def permute(w, heads):
    )


-
def eval_llama(
    model_name: str,
    args: argparse.Namespace,
@@ -166,7 +177,7 @@ def main() -> None:
    args.use_kv_cache = False
    args.prefill_ar_len = args.max_seq_length

-    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    args.device = "cuda" if torch.cuda.is_available() else "cpu"

    eval_llama(modelname, args)
