File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -74,11 +74,11 @@ def main() -> None:
7474    print (f"{ model .generation_config }  " )
7575
7676    tokenizer  =  AutoTokenizer .from_pretrained (args .hf_model_repo )
77-     input_ids  =  tokenizer ([ "" ],  return_tensors = "pt" ). to ( device )[ "input_ids" ] 
77+     input_ids  =  torch . tensor ([[ 1 ]],  dtype = torch . long ) 
7878    cache_position  =  torch .tensor ([0 ], dtype = torch .long )
7979
8080    def  _get_constant_methods (model : PreTrainedModel ):
81-         return  {
81+         metadata   =  {
8282            "get_dtype" : 5  if  model .config .torch_dtype  ==  torch .float16  else  6 ,
8383            "get_bos_id" : model .config .bos_token_id ,
8484            "get_eos_id" : model .config .eos_token_id ,
@@ -90,6 +90,7 @@ def _get_constant_methods(model: PreTrainedModel):
9090            "get_vocab_size" : model .config .vocab_size ,
9191            "use_kv_cache" : model .generation_config .use_cache ,
9292        }
93+         return  {k : v  for  k , v  in  metadata .items () if  v  is  not   None }
9394
9495    with  torch .nn .attention .sdpa_kernel ([SDPBackend .MATH ]), torch .no_grad ():
9596
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments