eval/task.py: 21 changes (20 additions, 1 deletion)

@@ -53,7 +53,26 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[Instance]:
if "4o" in model.model:
instance.args[1]["max_tokens"] = min(max_new_tokens, 16384)
elif isinstance(model, lm_eval_models.vllm_causallms.VLLM):
instance.args[1]["max_gen_toks"] = max_new_tokens
# Get prompt from instance.args[0] (the templated string)
prompt = instance.args[0]
Collaborator:
Thanks, can you wrap lines 57 to 64 in a try/except?
Also, maybe check whether prompt_length is extremely long (> max_model_len).

Author:
Good point, I'll wrap it in a try/except.

If prompt_length > max_model_len, we could log a warning. In that case max_allowed goes negative, so capped_max_new_tokens bottoms out at 1 and the prompt will be truncated to fit into the context window. Are we fine with that logic?

+                prompt_length = len(model.tokenizer.encode(prompt))
+
+                # Get max model length from vLLM engine
+                max_model_len = model.model.llm_engine.model_config.max_model_len
+
+                # Calculate max allowed generation tokens (16 token safety buffer)
+                max_allowed = max_model_len - prompt_length - 16
Collaborator:
Can you create a named constant instead of using 16 here?


+                # Cap to available space
+                capped_max_new_tokens = min(max_new_tokens, max(1, max_allowed))
+
+                if capped_max_new_tokens < max_new_tokens:
+                    self.logger.warning(
+                        f"max_new_tokens ({max_new_tokens}) capped to {capped_max_new_tokens} "
+                        f"(prompt: {prompt_length} tokens, model max: {max_model_len})"
+                    )
+
+                instance.args[1]["max_gen_toks"] = capped_max_new_tokens
             else:  # Huggingface
                 instance.args[1]["max_new_tokens"] = max_new_tokens
         return instances
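
A minimal sketch of how the capping logic could look once the review suggestions are folded in: the 16-token buffer becomes a named constant and the prompt-longer-than-context case logs its own warning. The helper name cap_max_new_tokens, the constant name GENERATION_SAFETY_BUFFER, and the module-level logger are illustrative assumptions, not code from this PR; the try/except asked for above would wrap the tokenizer.encode and llm_engine introspection at the call site.

import logging

logger = logging.getLogger(__name__)

# Named constant replacing the magic number 16 (reviewer suggestion);
# tokens reserved below max_model_len as a safety buffer.
GENERATION_SAFETY_BUFFER = 16


def cap_max_new_tokens(prompt_length: int, max_model_len: int, max_new_tokens: int) -> int:
    """Cap the requested generation length to the space left in the context window."""
    if prompt_length > max_model_len:
        # Prompt alone overflows the context window; vLLM truncates it and the
        # cap below bottoms out at a single generated token.
        logger.warning(
            "Prompt (%d tokens) exceeds max_model_len (%d); it will be truncated.",
            prompt_length, max_model_len,
        )

    max_allowed = max_model_len - prompt_length - GENERATION_SAFETY_BUFFER
    capped = min(max_new_tokens, max(1, max_allowed))

    if capped < max_new_tokens:
        logger.warning(
            "max_new_tokens (%d) capped to %d (prompt: %d tokens, model max: %d)",
            max_new_tokens, capped, prompt_length, max_model_len,
        )
    return capped


# Example: a 4000-token prompt against a 4096-token window leaves
# 4096 - 4000 - 16 = 80 tokens of room for generation.
assert cap_max_new_tokens(4000, 4096, 512) == 80
# A prompt longer than the window bottoms out at 1 generated token.
assert cap_max_new_tokens(4100, 4096, 512) == 1

With a helper like this, the vLLM branch would reduce to computing prompt_length and max_model_len inside a try/except and calling the helper, falling back to the uncapped max_new_tokens if tokenization or engine introspection fails.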