2525import torch .nn .functional as F
2626from lm_eval .api .instance import Instance
2727from lm_eval .models .huggingface import HFLM , TemplateLM
28- from lm_eval .models .utils import get_dtype , stop_sequences_criteria
28+ from lm_eval .models .utils import get_dtype
2929
3030# Local imports
3131from transformers import AutoModelForCausalLM , AutoTokenizer
3232from transformers .generation import GenerationConfig
3333
3434
35- logger = logging .getLogger (__name__ )
35+ eval_logger = logging .getLogger (__name__ )
3636
3737
3838class HabanaModelAdapter (HFLM ):
@@ -100,7 +100,7 @@ def __init__(
100100 )
101101 if "gemma" in getattr (self ._config , "model_type" , "" ):
102102 self .add_bos_token = True
103- logger .info (
103+ eval_logger .info (
104104 f"Model type is '{ self ._config .model_type } ', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
105105 )
106106 self .batch_size_per_gpu = int (args .batch_size )
@@ -170,21 +170,23 @@ def max_length(self) -> int:
170170
171171 @property
172172 def device (self ):
173- # We need to do padding ourselves, otherwise we'll end up with recompilations
174- # Returning 'cpu' to keep tensors on CPU in lm_eval code
175- return "cpu"
173+ return torch .device ("hpu" )
176174
177175 @max_length .setter
178176 def max_length (self , value : int ) -> None :
179177 self ._max_length = value
180178
181179 def find_bucket (self , length : int , key = lambda b , length : b >= length ) -> int :
180+ """
181+ Find the smallest bucket >= length, or add a new one.
182+ """
182183 for b in self .buckets :
183184 if key (b , length ):
184185 return b
185186 new_bucket = length
186187 self .buckets .append (new_bucket )
187188 self .buckets .sort ()
189+ eval_logger .info (f"Added new bucket: { new_bucket } . Buckets are now: { self .buckets } " )
188190 return new_bucket
189191
190192 def _model_call (self , inps : torch .Tensor ) -> torch .Tensor :
@@ -195,13 +197,13 @@ def _model_call(self, inps: torch.Tensor) -> torch.Tensor:
195197 if self .options .use_cache and self .options .reuse_cache :
196198 self ._model .allocate_kv_cache (bs , bucket_length + 1 , bucket_length )
197199 padding_length = bucket_length - seq_length
198- inps = F .pad (inps , (0 , padding_length ), value = self ._model .config .pad_token_id )
199- logits = self ._model (inps .to (self .device_ ), ** self .model_inputs )["logits" ].cpu ()
200+ pad_token_id = getattr (self ._model .config , "pad_token_id" , 0 )
201+ inps = F .pad (inps , (0 , padding_length ), value = pad_token_id )
202+ eval_logger .debug (f"Padded input from { seq_length } to { bucket_length } (pad={ padding_length } )" )
203+ logits = self ._model (inps .to (self .device ), ** self .model_inputs )["logits" ]
200204
201205 if self .options .static_shapes and padding_length > 0 :
202206 logits = logits [:, :- padding_length , :]
203- logits = logits .to (torch .float32 )
204-
205207 return logits
206208
207209 def generate_until (self , requests : list [Instance ], disable_tqdm : bool = False ) -> list [str ]:
@@ -217,7 +219,7 @@ def generate_until(self, requests: list[Instance], disable_tqdm: bool = False) -
217219
218220 def _model_generate (
219221 self ,
220- context ,
222+ context : torch . Tensor ,
221223 max_length : int ,
222224 stop : list [str ],
223225 ** generation_kwargs : dict [str , Any ],
@@ -226,21 +228,12 @@ def _model_generate(
226228 Patched method
227229 source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/models/huggingface.py#L951
228230 """
229- # temperature = 0.0 if not set
230- # if do_sample is false and temp==0.0:
231- # remove temperature, as do_sample=False takes care of this
232- # and we don't want a warning from HF
233231 generation_kwargs ["temperature" ] = generation_kwargs .get ("temperature" , 0.0 )
234232 do_sample = generation_kwargs .get ("do_sample" )
235- # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
236233 if generation_kwargs .get ("temperature" ) == 0.0 and do_sample is None :
237234 generation_kwargs ["do_sample" ] = do_sample = False
238-
239235 if do_sample is False and generation_kwargs .get ("temperature" ) == 0.0 :
240236 generation_kwargs .pop ("temperature" )
241- # build stopping criteria
242- stopping_criteria = stop_sequences_criteria (self .tokenizer , stop , context .shape [1 ], context .shape [0 ])
243- # to avoid graph recompilation
244237 if self .options .static_shapes :
245238 self .options .bucket_internal = True
246239 bucket_length = self .find_bucket (context .shape [1 ])
@@ -254,17 +247,16 @@ def _model_generate(
254247 generation_kwargs ["attention_mask" ], (0 , padding_length ), value = 0
255248 )
256249 # move context & attention_mask to hpu
257- context = context .to ("hpu" )
258- generation_kwargs ["attention_mask" ] = generation_kwargs ["attention_mask" ].to ("hpu" )
250+ context = context .to (self . device )
251+ generation_kwargs ["attention_mask" ] = generation_kwargs ["attention_mask" ].to (self . device )
259252 with torch .autocast (
260- device_type = "hpu" ,
253+ device_type = self . device ,
261254 dtype = self .mixed_precision_dtype ,
262255 enabled = self .mixed_precision_dtype is not None ,
263256 ):
264257 return self .model .generate (
265258 input_ids = context ,
266259 max_new_tokens = max_gen_toks ,
267- stopping_criteria = stopping_criteria ,
268260 pad_token_id = self .tokenizer .pad_token_id ,
269261 use_cache = True ,
270262 hpu_graphs = self .hpu_graphs ,
0 commit comments