File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -73,12 +73,11 @@ def main() -> None:
7373 print (f"{ model .config } " )
7474 print (f"{ model .generation_config } " )
7575
76- tokenizer = AutoTokenizer .from_pretrained (args .hf_model_repo )
77- input_ids = tokenizer (["" ], return_tensors = "pt" ).to (device )["input_ids" ]
76+ input_ids = torch .tensor ([[1 ]], dtype = torch .long )
7877 cache_position = torch .tensor ([0 ], dtype = torch .long )
7978
8079 def _get_constant_methods (model : PreTrainedModel ):
81- return {
80+ metadata = {
8281 "get_dtype" : 5 if model .config .torch_dtype == torch .float16 else 6 ,
8382 "get_bos_id" : model .config .bos_token_id ,
8483 "get_eos_id" : model .config .eos_token_id ,
@@ -90,6 +89,7 @@ def _get_constant_methods(model: PreTrainedModel):
9089 "get_vocab_size" : model .config .vocab_size ,
9190 "use_kv_cache" : model .generation_config .use_cache ,
9291 }
92+ return {k : v for k , v in metadata .items () if v is not None }
9393
9494 with torch .nn .attention .sdpa_kernel ([SDPBackend .MATH ]), torch .no_grad ():
9595
You can’t perform that action at this time.
0 commit comments