Commit 4b06b94
[FIX] Fixes vllm backend (#317)
* fix vllm
* fix long loglikelihood context in vllm backend
* removes the need for pytest hook function
* fix model max length
1 parent 7295c78 commit 4b06b94

3 files changed: +48 additions, -36 deletions

src/lighteval/models/base_model.py

Lines changed: 2 additions & 3 deletions

@@ -126,7 +126,7 @@ def add_special_tokens(self):
     def max_length(self) -> int:
         return self._max_length
 
-    def init_model_parallel(self, model_parallel: bool = None) -> Tuple[bool, Optional[dict], Optional[str]]:
+    def init_model_parallel(self, model_parallel: bool | None = None) -> Tuple[bool, Optional[dict], Optional[str]]:
         """Compute all the parameters related to model_parallel"""
         if not is_accelerate_available():
             return False, None, None
@@ -147,7 +147,7 @@ def init_model_parallel(self, model_parallel: bool = None) -> Tuple[bool, Option
             f"the number of local processes is {self.num_local_processes} "
             f"and the number of GPUs is {len(max_memory_all_gpus)}"
         )
-        if model_parallel:
+        if model_parallel is True:
             max_memory_all_gpus = get_max_memory()  # A dict of the max memory for all the gpus
             if "cpu" in max_memory_all_gpus:
                 del max_memory_all_gpus["cpu"]
@@ -569,7 +569,6 @@ def greedy_until(
         if max_new_tokens is None:  # If generation size is not set, we go all the way
            max_new_tokens = self.max_length - context_size
         else:
-            print(self.max_length, context_size, max_new_tokens)
             max_new_tokens = min(self.max_length - context_size, max_new_tokens)
             if max_new_tokens < 1:
                 max_new_tokens = 1
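
To make the generation-budget logic in the last hunk concrete, here is a minimal stand-alone sketch of the clamping that greedy_until performs after this change (the helper name and the example numbers are illustrative, not part of the patch):

def resolve_max_new_tokens(max_length: int, context_size: int, max_new_tokens: int | None) -> int:
    """Clamp the generation budget to the room left in the context window."""
    if max_new_tokens is None:  # generation size not set: use all remaining room
        return max_length - context_size
    # otherwise cap it, but always allow at least one new token
    return max(1, min(max_length - context_size, max_new_tokens))


# Illustrative numbers only:
assert resolve_max_new_tokens(max_length=4096, context_size=4000, max_new_tokens=512) == 96
assert resolve_max_new_tokens(max_length=2048, context_size=2047, max_new_tokens=256) == 1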

src/lighteval/models/model_config.py

Lines changed: 11 additions & 7 deletions

@@ -118,6 +118,8 @@ class BaseModelConfig:
     def __post_init__(self):
         # Making sure this parameter is a boolean
         self.multichoice_continuations_start_space = boolstring_to_bool(self.multichoice_continuations_start_space)
+        self.model_parallel = boolstring_to_bool(self.model_parallel)
+        self.compile = boolstring_to_bool(self.compile)
 
         if self.quantization_config is not None and not is_bnb_available():
             raise ImportError(NO_BNB_ERROR_MSG)
@@ -209,19 +211,21 @@ def init_configs(self, env_config: EnvConfig):
 @dataclass
 class VLLMModelConfig:
     pretrained: str
-    gpu_memory_utilisation: float = 0.8
-    batch_size: int = -1
-    revision: str = "main"
+    gpu_memory_utilisation: float = 0.9  # lower this if you are running out of memory
+    revision: str = "main"  # revision of the model
     dtype: str | None = None
-    tensor_parallel_size: int = 1
-    data_parallel_size: int = 1
-    max_model_length: int = 1024
+    tensor_parallel_size: int = 1  # how many GPUs to use for tensor parallelism
+    pipeline_parallel_size: int = 1  # how many GPUs to use for pipeline parallelism
+    data_parallel_size: int = 1  # how many GPUs to use for data parallelism
+    max_model_length: int | None = None  # maximum length of the model, usually inferred automatically; reduce this if you encounter OOM issues, 4096 is usually enough
     swap_space: int = 4  # CPU swap space size (GiB) per GPU.
     seed: int = 1234
     trust_remote_code: bool = False
     use_chat_template: bool = False
     add_special_tokens: bool = True
-    multichoice_continuations_start_space: bool = True
+    multichoice_continuations_start_space: bool = (
+        True  # whether to add a space at the start of each continuation in multichoice generation
+    )
     subfolder: Optional[str] = None
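
As a quick usage sketch (the import path and the model id are assumptions for illustration, not taken from the diff), the new defaults let the backend pick the context length automatically:

from lighteval.models.model_config import VLLMModelConfig  # assumed import path

# Hypothetical configuration; only fields shown in the diff above are set.
config = VLLMModelConfig(
    pretrained="HuggingFaceH4/zephyr-7b-beta",  # any HF model id (example only)
    gpu_memory_utilisation=0.9,   # lower this if you run out of GPU memory
    tensor_parallel_size=2,       # shard the weights over 2 GPUs
    max_model_length=None,        # None = infer the context length automatically
)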

src/lighteval/models/vllm_model.py

Lines changed: 35 additions & 26 deletions

@@ -68,14 +68,17 @@ def __init__(
     ):
         """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation."""
         self._config = config
-        self._batch_size = config.batch_size
-        self._max_length = self._init_max_length(config.max_model_length)
         self.use_chat_template = config.use_chat_template
         self.data_parallel_size = int(config.data_parallel_size)
 
         self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False
         self._tokenizer = self._create_auto_tokenizer(config, env_config)
 
+        if config.max_model_length is not None:
+            self._max_length = int(config.max_model_length)
+        else:
+            self._max_length = self.tokenizer.model_max_length or self.tokenizer.max_position_embeddings
+
         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = self._create_auto_model(config, env_config)
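
A condensed view of the new max-length resolution above, as a stand-alone sketch (the helper name is illustrative; the attribute fallback mirrors the diff, assuming the tokenizer exposes those attributes):

def resolve_max_length(max_model_length: int | None, tokenizer) -> int | None:
    """Prefer an explicit config value, otherwise fall back to what the tokenizer reports."""
    if max_model_length is not None:
        return int(max_model_length)
    # As in the diff: model_max_length first, then max_position_embeddings if present.
    return tokenizer.model_max_length or getattr(tokenizer, "max_position_embeddings", None)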

@@ -120,12 +123,13 @@ def _create_auto_model(self, config: VLLMModelConfig, env_config: EnvConfig) ->
         """
         self.model_args = {
             "model": config.pretrained,
-            "gpu_memory_utilization": float(0.8),
+            "gpu_memory_utilization": float(config.gpu_memory_utilisation),
             "revision": config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""),
             "dtype": config.dtype,
             "trust_remote_code": config.trust_remote_code,
-            "tensor_parallel_size": int(1),
-            "max_model_len": int(self._max_length) if self._max_length else None,
+            "tensor_parallel_size": int(config.tensor_parallel_size),
+            "pipeline_parallel_size": int(config.pipeline_parallel_size),
+            "max_model_len": self._max_length,
             "swap_space": 4,
             "seed": 1234,
         }
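
For orientation, these keys line up with the keyword arguments of vLLM's LLM engine; a minimal sketch of how such a dict is typically unpacked (this is an assumption about the call site, not code from this file, and the model id is hypothetical):

from vllm import LLM  # requires vllm to be installed

# Sketch only: the keys mirror vLLM's LLM/EngineArgs parameters shown above.
model_args = {
    "model": "HuggingFaceH4/zephyr-7b-beta",  # stands in for config.pretrained
    "gpu_memory_utilization": 0.9,
    "tensor_parallel_size": 1,
    "max_model_len": 4096,
    "swap_space": 4,
    "seed": 1234,
}
llm = LLM(**model_args)  # downloads and loads the model, so run this on a GPU machine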
@@ -227,30 +231,33 @@ def greedy_until(
         # of losing some meaning, or have some generations that are exceedingly short?
         # The choice we go for here is to avoid truncating the prompt if we can, since it
         # should have been managed by the prompt creator/few shot manager if requested by the user.
-        context_size = len(tokenized["input_ids"][0])
-        if context_size > self.max_length:
-            hlog_warn(
-                f"The context size of your batch ({context_size}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
-                + str({dataset[0].task_name})
-                + ". This is likely to lead to some errors."  # noqa C401
-            )
-            # There will be truncation of at least one sample, maximum generation size will be one
-            max_new_tokens = 1
-        else:  # We can't allow generation of more than max_length
-            if max_new_tokens is None:  # If generation size is not set, we go all the way
-                max_new_tokens = self.max_length - context_size
-            else:
-                max_new_tokens = min(self.max_length - context_size, max_new_tokens)
+        inputs = tokenized["input_ids"]
+        context_size = len(inputs[0])
+
+        # left truncate the inputs to the maximum length
+        if max_new_tokens is not None:
+            if context_size + max_new_tokens > self.max_length:
+                hlog_warn(
+                    f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens."
+                )
+                context_size = self.max_length - max_new_tokens
+                inputs = [input[-context_size:] for input in inputs]
+        else:
+            if context_size > self.max_length:
+                hlog_warn(
+                    f"{context_size=} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens."
+                )
+                context_size = self.max_length
+                inputs = [input[-context_size:] for input in inputs]
 
         vllm_outputs = self._generate(
-            inputs=tokenized["input_ids"],
+            inputs=inputs,
             max_new_tokens=max_new_tokens,
             stop_tokens=stop_tokens,
             returns_logits=returns_logits,
             num_samples=num_samples,
         )
 
-        print(f"{len(vllm_outputs)} vllm_outputs")
         for vllm_output in vllm_outputs:
             output_token_ids = [outputs.token_ids for outputs in vllm_output.outputs]
             logprobs = [output.logprobs for output in vllm_output.outputs] or []
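
The truncation rule above can be summarised as a small pure function; a sketch under the assumption that inputs is a list of token-id lists (the helper name is illustrative, not from the file):

def left_truncate(inputs: list[list[int]], max_length: int, max_new_tokens: int | None) -> list[list[int]]:
    """Keep only the rightmost context tokens so that context + generation fits in max_length."""
    if max_new_tokens is not None:
        budget = max_length - max_new_tokens  # leave room for the requested generation
    else:
        budget = max_length                   # no generation size set: cap at the window itself
    return [input_ids[-budget:] for input_ids in inputs]

# Illustrative check: with a 10-token window and 4 new tokens, only the last 6 context tokens survive.
assert left_truncate([list(range(12))], max_length=10, max_new_tokens=4) == [[6, 7, 8, 9, 10, 11]]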
@@ -345,19 +352,21 @@ def _loglikelihood_tokens(
 
         for _ in tqdm(dataset.splits_start_end_iterator()):
             # the last token is an eos token, so we don't need to add it
-            inputs = [
-                dataset[i].tokenized_context + dataset[i].tokenized_continuation[:-1] for i in range(len(dataset))
-            ]
+            inputs = [dataset[i].tokenized_context + dataset[i].tokenized_continuation for i in range(len(dataset))]
+            # Left truncate the inputs to the maximum length
+            inputs = [input[-self.max_length :] for input in inputs]
             outputs = self._generate(inputs, generate=False)
 
             for output, input in zip(outputs, dataset):
                 continuation_logprobs = []
-                for token, logprobs in zip(input.tokenized_continuation[-2::-1], output.prompt_logprobs[::-1]):
+                for token, logprobs in zip(input.tokenized_continuation[::-1], output.prompt_logprobs[::-1]):
                     continuation_logprobs.append(logprobs[token])
                 bool_score = all(logprob.rank == 1 for logprob in continuation_logprobs)
                 continuation_logprobs = [logprob.logprob for logprob in continuation_logprobs]
                 answer = LoglikelihoodResponse(
-                    result=(sum(continuation_logprobs), bool_score if return_bool_score else None)
+                    input_tokens=input.tokenized_context + input.tokenized_continuation,
+                    generated_tokens=input.tokenized_continuation,
+                    result=(sum(continuation_logprobs), bool_score if return_bool_score else None),
                 )
                 res.append(answer)
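
To see why iterating both sequences in reverse lines the continuation tokens up with vLLM's prompt_logprobs, here is a toy sketch with plain dicts and floats standing in for vLLM's Logprob objects (all token ids and values are made up for illustration):

# Context = [5, 6], continuation = [7, 8]; the prompt fed to vLLM is their concatenation.
tokenized_continuation = [7, 8]
# One logprob dict per prompt position; the last len(continuation) entries cover the continuation.
prompt_logprobs = [
    None,          # vLLM returns no logprob for the very first prompt token
    {6: -0.1},
    {7: -0.5},     # logprob of token 7 given [5, 6]
    {8: -0.25},    # logprob of token 8 given [5, 6, 7]
]

# Reversed iteration aligns the two lists from the end, so the context length never matters.
continuation_logprobs = [
    logprobs[token]
    for token, logprobs in zip(tokenized_continuation[::-1], prompt_logprobs[::-1])
]
assert continuation_logprobs == [-0.25, -0.5]
assert sum(continuation_logprobs) == -0.75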
