@@ -44,14 +44,15 @@ def __init__(
         tokenizer: Union[
             SentencePieceTokenizer, TiktokenTokenizer, HuggingFaceTokenizer
         ],
-        max_seq_length: Optional[int],
+        max_seq_length: int,
         ar_len: int,
         use_kv_cache: bool,
         get_example_inputs: Callable,
         kv_updater: Callable,
         use_i64_token: bool,
     ):
         # n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
+        assert max_seq_length is not None, "max_seq_length must be provided"
         super().__init__(
             model=model, tokenizer=tokenizer, max_seq_length=max_seq_length - 1
         )
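The hunk above narrows `max_seq_length` from `Optional[int]` to `int` yet still adds a runtime assert: Python annotations are not enforced, so a caller can pass `None` regardless of the signature, and the assert guards the `max_seq_length - 1` arithmetic with a clear message. A minimal sketch of the pattern, with hypothetical names:

def build_eval_wrapper(max_seq_length: int) -> int:
    # Fails early with a clear message instead of a confusing
    # "unsupported operand" TypeError on `None - 1` below.
    assert max_seq_length is not None, "max_seq_length must be provided"
    return max_seq_length - 1

print(build_eval_wrapper(512))  # 511
# build_eval_wrapper(None) -> AssertionError, despite the `int` annotation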
@@ -119,8 +120,10 @@ def __init__(
         for method in program.execution_plan:
             # Don't use tokenizer.n_words, the numbers are off once calling get_tokenizer()
             if method.name == "get_vocab_size":
+                # pyre-ignore
                 self.output_vocab_size = method.values[0].val.int_val
             if method.name == "get_max_seq_len":
+                # pyre-ignore
                 pte_max_seq_len = method.values[0].val.int_val
         assert self.output_vocab_size is not None, "Couldn't find the vocab size"
         assert pte_max_seq_len is not None, "Couldn't find the max_seq_len from pte"
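The `# pyre-ignore` comments suppress a checker error on `method.values[0].val.int_val`: the `val` field of an execution-plan value is union-typed, so Pyre cannot prove that `int_val` exists on it, while the asserts that follow perform the actual runtime check. A simplified stand-in, assuming a dataclass in place of the real ExecuTorch schema:

from dataclasses import dataclass
from typing import Optional

@dataclass
class IntVal:  # hypothetical stand-in for one member of the value union
    int_val: int

output_vocab_size: Optional[int] = None
val = IntVal(int_val=32000)  # hypothetical metadata entry
output_vocab_size = val.int_val  # the real, union-typed access needs pyre-ignore
assert output_vocab_size is not None, "Couldn't find the vocab size"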
@@ -156,6 +159,7 @@ def __init__(
         )
         self.adb.push(inputs=[], input_list="", files=[self.runtime_tokenizer_path])
         # n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
+        # pyre-ignore
         super().__init__(None, tokenizer, max_seq_length - 1)

     def _model_call(self, inps):
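Here the suppression covers the `None` passed where the base eval wrapper expects a model object; this wrapper drives the model on-device through `adb`, so there is no in-process module to supply. A sketch of the mismatch, assuming a hypothetical simplified base class:

class BaseEvalWrapper:
    # The parent treats `model` as required, which is why passing
    # None in the subclass needs a pyre-ignore.
    def __init__(self, model, tokenizer, max_seq_length: int):
        self.model = model  # never used by the on-device subclass
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

class OnDeviceEvalWrapper(BaseEvalWrapper):
    def __init__(self, tokenizer, max_seq_length: int):
        # No in-process model; inference happens on the device via adb.
        super().__init__(None, tokenizer, max_seq_length - 1)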
@@ -278,6 +282,7 @@ def kv_inference(
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
+        # pyre-ignore
         token_list = prompt.flatten().tolist()
     pos = len(token_list) if len(token_list) < ar_len else ar_len
     dtype = torch.int64 if use_i64_token else torch.int32
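For context, the line after the new suppression clamps the starting position to the autoregressive window: `pos` is the prompt length unless it exceeds `ar_len`. The conditional is equivalent to `min`, as this small check with hypothetical values illustrates:

token_list = [1, 2, 3, 4, 5, 6]  # hypothetical prompt tokens
ar_len = 4                       # autoregressive window size
pos = len(token_list) if len(token_list) < ar_len else ar_len
assert pos == min(len(token_list), ar_len)  # pos == 4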
@@ -359,6 +364,7 @@ def prefill_inference(
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
+        # pyre-ignore
         token_list = prompt.flatten().tolist()

     pos = len(token_list)
@@ -405,7 +411,7 @@ def graph_module_inference(
     max_seq_len=512,
     kv_updater=smart_mask_updater,
     use_i64_token=False,
-    event_name: str = None,
+    event_name: Optional[str] = None,
 ):
     if args.tasks is None:
         if use_kv_cache:
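`event_name: str = None` relies on implicit `Optional`, which PEP 484 deprecates and current checkers reject, since the annotation `str` does not admit `None`; spelling it `Optional[str] = None` makes the contract explicit. A minimal sketch with a hypothetical function:

from typing import Optional

def log_event(event_name: Optional[str] = None) -> str:
    # `event_name: str = None` would be flagged by Pyre/mypy here.
    return event_name or "unnamed"

print(log_event())           # unnamed
print(log_event("prefill"))  # prefill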