@@ -44,14 +44,15 @@ def __init__(
4444 tokenizer : Union [
4545 SentencePieceTokenizer , TiktokenTokenizer , HuggingFaceTokenizer
4646 ],
47- max_seq_length : Optional [ int ] ,
47+ max_seq_length : int ,
4848 ar_len : int ,
4949 use_kv_cache : bool ,
5050 get_example_inputs : Callable ,
5151 kv_updater : Callable ,
5252 use_i64_token : bool ,
5353 ):
5454 # n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
55+ assert max_seq_length is not None , "max_seq_length must be provided"
5556 super ().__init__ (
5657 model = model , tokenizer = tokenizer , max_seq_length = max_seq_length - 1
5758 )
@@ -119,8 +120,10 @@ def __init__(
119120 for method in program .execution_plan :
120121 # Don't use tokenizer.n_words, the numbers are off once calling get_tokenizer()
121122 if method .name == "get_vocab_size" :
123+ # pyre-ignore
122124 self .output_vocab_size = method .values [0 ].val .int_val
123125 if method .name == "get_max_seq_len" :
126+ # pyre-ignore
124127 pte_max_seq_len = method .values [0 ].val .int_val
125128 assert self .output_vocab_size is not None , "Couldn't find the vocab size"
126129 assert pte_max_seq_len is not None , "Couldn't find the max_seq_len from pte"
@@ -156,6 +159,7 @@ def __init__(
156159 )
157160 self .adb .push (inputs = [], input_list = "" , files = [self .runtime_tokenizer_path ])
158161 # n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
162+ # pyre-ignore
159163 super ().__init__ (None , tokenizer , max_seq_length - 1 )
160164
161165 def _model_call (self , inps ):
@@ -278,6 +282,7 @@ def kv_inference(
278282 else :
279283 raise RuntimeError ("Unknown tokenizer" )
280284 else :
285+ # pyre-ignore
281286 token_list = prompt .flatten ().tolist ()
282287 pos = len (token_list ) if len (token_list ) < ar_len else ar_len
283288 dtype = torch .int64 if use_i64_token else torch .int32
@@ -359,6 +364,7 @@ def prefill_inference(
359364 else :
360365 raise RuntimeError ("Unknown tokenizer" )
361366 else :
367+ # pyre-ignore
362368 token_list = prompt .flatten ().tolist ()
363369
364370 pos = len (token_list )
@@ -405,7 +411,7 @@ def graph_module_inference(
405411 max_seq_len = 512 ,
406412 kv_updater = smart_mask_updater ,
407413 use_i64_token = False ,
408- event_name : str = None ,
414+ event_name : Optional [ str ] = None ,
409415):
410416 if args .tasks is None :
411417 if use_kv_cache :
0 commit comments