Skip to content

Commit 472e93d

Browse files
authored
lm-eval polishing and speed-up (#2361)
1 parent 72dea90 commit 472e93d

File tree

2 files changed

+17
-25
lines changed

2 files changed

+17
-25
lines changed

examples/text-generation/model_adapter.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525
import torch.nn.functional as F
2626
from lm_eval.api.instance import Instance
2727
from lm_eval.models.huggingface import HFLM, TemplateLM
28-
from lm_eval.models.utils import get_dtype, stop_sequences_criteria
28+
from lm_eval.models.utils import get_dtype
2929

3030
# Local imports
3131
from transformers import AutoModelForCausalLM, AutoTokenizer
3232
from transformers.generation import GenerationConfig
3333

3434

35-
logger = logging.getLogger(__name__)
35+
eval_logger = logging.getLogger(__name__)
3636

3737

3838
class HabanaModelAdapter(HFLM):
@@ -100,7 +100,7 @@ def __init__(
100100
)
101101
if "gemma" in getattr(self._config, "model_type", ""):
102102
self.add_bos_token = True
103-
logger.info(
103+
eval_logger.info(
104104
f"Model type is '{self._config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
105105
)
106106
self.batch_size_per_gpu = int(args.batch_size)
@@ -170,21 +170,23 @@ def max_length(self) -> int:
170170

171171
@property
172172
def device(self):
173-
# We need to do padding ourselves, otherwise we'll end up with recompilations
174-
# Returning 'cpu' to keep tensors on CPU in lm_eval code
175-
return "cpu"
173+
return torch.device("hpu")
176174

177175
@max_length.setter
178176
def max_length(self, value: int) -> None:
179177
self._max_length = value
180178

181179
def find_bucket(self, length: int, key=lambda b, length: b >= length) -> int:
180+
"""
181+
Find the smallest bucket >= length, or add a new one.
182+
"""
182183
for b in self.buckets:
183184
if key(b, length):
184185
return b
185186
new_bucket = length
186187
self.buckets.append(new_bucket)
187188
self.buckets.sort()
189+
eval_logger.info(f"Added new bucket: {new_bucket}. Buckets are now: {self.buckets}")
188190
return new_bucket
189191

190192
def _model_call(self, inps: torch.Tensor) -> torch.Tensor:
@@ -195,13 +197,13 @@ def _model_call(self, inps: torch.Tensor) -> torch.Tensor:
195197
if self.options.use_cache and self.options.reuse_cache:
196198
self._model.allocate_kv_cache(bs, bucket_length + 1, bucket_length)
197199
padding_length = bucket_length - seq_length
198-
inps = F.pad(inps, (0, padding_length), value=self._model.config.pad_token_id)
199-
logits = self._model(inps.to(self.device_), **self.model_inputs)["logits"].cpu()
200+
pad_token_id = getattr(self._model.config, "pad_token_id", 0)
201+
inps = F.pad(inps, (0, padding_length), value=pad_token_id)
202+
eval_logger.debug(f"Padded input from {seq_length} to {bucket_length} (pad={padding_length})")
203+
logits = self._model(inps.to(self.device), **self.model_inputs)["logits"]
200204

201205
if self.options.static_shapes and padding_length > 0:
202206
logits = logits[:, :-padding_length, :]
203-
logits = logits.to(torch.float32)
204-
205207
return logits
206208

207209
def generate_until(self, requests: list[Instance], disable_tqdm: bool = False) -> list[str]:
@@ -217,7 +219,7 @@ def generate_until(self, requests: list[Instance], disable_tqdm: bool = False) -
217219

218220
def _model_generate(
219221
self,
220-
context,
222+
context: torch.Tensor,
221223
max_length: int,
222224
stop: list[str],
223225
**generation_kwargs: dict[str, Any],
@@ -226,21 +228,12 @@ def _model_generate(
226228
Patched method
227229
source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/models/huggingface.py#L951
228230
"""
229-
# temperature = 0.0 if not set
230-
# if do_sample is false and temp==0.0:
231-
# remove temperature, as do_sample=False takes care of this
232-
# and we don't want a warning from HF
233231
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
234232
do_sample = generation_kwargs.get("do_sample")
235-
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
236233
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
237234
generation_kwargs["do_sample"] = do_sample = False
238-
239235
if do_sample is False and generation_kwargs.get("temperature") == 0.0:
240236
generation_kwargs.pop("temperature")
241-
# build stopping criteria
242-
stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, context.shape[1], context.shape[0])
243-
# to avoid graph recompilation
244237
if self.options.static_shapes:
245238
self.options.bucket_internal = True
246239
bucket_length = self.find_bucket(context.shape[1])
@@ -254,17 +247,16 @@ def _model_generate(
254247
generation_kwargs["attention_mask"], (0, padding_length), value=0
255248
)
256249
# move context & attention_mask to hpu
257-
context = context.to("hpu")
258-
generation_kwargs["attention_mask"] = generation_kwargs["attention_mask"].to("hpu")
250+
context = context.to(self.device)
251+
generation_kwargs["attention_mask"] = generation_kwargs["attention_mask"].to(self.device)
259252
with torch.autocast(
260-
device_type="hpu",
253+
device_type=self.device,
261254
dtype=self.mixed_precision_dtype,
262255
enabled=self.mixed_precision_dtype is not None,
263256
):
264257
return self.model.generate(
265258
input_ids=context,
266259
max_new_tokens=max_gen_toks,
267-
stopping_criteria=stopping_criteria,
268260
pad_token_id=self.tokenizer.pad_token_id,
269261
use_cache=True,
270262
hpu_graphs=self.hpu_graphs,

examples/text-generation/run_lm_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def setup_lm_eval_parser():
145145
"--metadata",
146146
type=json.loads,
147147
default=None,
148-
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
148+
help="""JSON string metadata to pass to task configs, for example '{"max_length":1024}'. Will be merged with model_args. Can also be set in task config.""",
149149
)
150150
parser.add_argument(
151151
"--apply_chat_template",

0 commit comments

Comments (0)