
Commit d6de1fe

NathanHB and gary149 authored
allows better flexibility for litellm endpoints (#549)
* allows better flexibility for litellm
* add config file
* add doc
* add doc
* add doc
* add doc
* add lighteval imgs
* add lighteval imgs
* add doc
* Update docs/source/use-litellm-as-backend.mdx
  Co-authored-by: Victor Muštar <[email protected]>
* Update docs/source/use-litellm-as-backend.mdx

---------

Co-authored-by: Victor Muštar <[email protected]>
1 parent fac17bb commit d6de1fe

File tree

6 files changed: +168 −18 lines changed

docs/source/_toctree.yml
docs/source/use-litellm-as-backend.mdx
examples/model_configs/litellm_model.yaml
src/lighteval/main_endpoint.py
src/lighteval/models/litellm_model.py
src/lighteval/models/model_input.py

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@
 - sections:
   - local: saving-and-reading-results
     title: Save and read results
+  - local: use-litellm-as-backend
+    title: Use LITELLM as backend
   - local: using-the-python-api
     title: Use the Python API
   - local: adding-a-custom-task
docs/source/use-litellm-as-backend.mdx

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
# Litellm as backend

Lighteval lets you use litellm, a backend that can call all LLM APIs using the
OpenAI format (Bedrock, Hugging Face, VertexAI, TogetherAI, Azure, OpenAI,
Groq, etc.).

Documentation for the available APIs and compatible endpoints can be found [here](https://docs.litellm.ai/docs/).

## Quick use

```bash
lighteval endpoint litellm \
    "gpt-3.5-turbo" \
    "lighteval|gsm8k|0|0"
```

## Using a config file

Litellm allows generation with any OpenAI-compatible endpoint; for example, you
can evaluate a model running on a local vllm server.

To do so, you will need to use a config file like this:

```yaml
model:
  base_params:
    model_name: "openai/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
    base_url: "URL OF THE ENDPOINT YOU WANT TO USE"
    api_key: "" # remove or keep empty as needed
  generation:
    temperature: 0.5
    max_new_tokens: 256
    stop_tokens: [""]
    top_p: 0.9
    seed: 0
    repetition_penalty: 1.0
    frequency_penalty: 0.0
```

## Use Hugging Face Inference Providers

With this you can also access Hugging Face inference servers. Let's look at how to evaluate DeepSeek-R1-Distill-Qwen-32B.

First, let's see how to access the model; we can find this on [the model card](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B).

Step 1:

![Step 1](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lighteval/litellm-guide-2.png)

Step 2:

![Step 2](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lighteval/litellm-guide-1.png)

Great! Now we can simply copy-paste the base_url and our API key to evaluate our model.

> [!WARNING]
> Do not forget to prepend the provider in the `model_name`. Here we use an
> OpenAI-compatible endpoint, so the provider is `openai`.

```yaml
model:
  base_params:
    model_name: "openai/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
    base_url: "https://router.huggingface.co/hf-inference/v1"
    api_key: "YOUR KEY" # remove or keep empty as needed
  generation:
    temperature: 0.5
    max_new_tokens: 256 # This will override the default from the task's config
    top_p: 0.9
    seed: 0
    repetition_penalty: 1.0
    frequency_penalty: 0.0
```

We can then evaluate our model on any eval available in Lighteval:

```bash
lighteval endpoint litellm \
    "examples/model_configs/litellm_model.yaml" \
    "lighteval|gsm8k|0|0"
```
examples/model_configs/litellm_model.yaml

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
model:
  base_params:
    model_name: "openai/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
    base_url: "https://router.huggingface.co/hf-inference/v1"
  generation:
    temperature: 0.5
    max_new_tokens: 256
    stop_tokens: [""]
    top_p: 0.9
    seed: 0
    repetition_penalty: 1.0
    frequency_penalty: 0.0

src/lighteval/main_endpoint.py

Lines changed: 10 additions & 3 deletions
@@ -385,8 +385,11 @@ def tgi(
 @app.command(rich_help_panel="Evaluation Backends")
 def litellm(
     # === general ===
-    model_name: Annotated[
-        str, Argument(help="The model name to evaluate (has to be available through the litellm API.")
+    model_args: Annotated[
+        str,
+        Argument(
+            help="config file path for the litellm model, or a comma separated string of model args (model_name={},base_url={},provider={})"
+        ),
     ],
     tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
     # === Common parameters ===
@@ -462,7 +465,11 @@ def litellm(
     # TODO (nathan): better handling of model_args
     parallelism_manager = ParallelismManager.NONE
 
-    model_config = LiteLLMModelConfig(model=model_name)
+    if model_args.endswith(".yaml"):
+        model_config = LiteLLMModelConfig.from_path(model_args)
+    else:
+        model_name = model_args.split(",")[0].strip()
+        model_config = LiteLLMModelConfig(model=model_name)
 
     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,
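
The help string above advertises a comma-separated form (`model_name={},base_url={},provider={}`), but per the TODO only the first field is currently used when the argument is not a YAML path. A minimal standalone sketch of that dispatch (illustrative values, not part of the commit):

```python
# Sketch of the model_args dispatch added above; the real command builds a
# LiteLLMModelConfig instead of returning strings.
def resolve_model_args(model_args: str) -> str:
    if model_args.endswith(".yaml"):
        # In the command this becomes LiteLLMModelConfig.from_path(model_args).
        return f"load config from {model_args}"
    # Otherwise only the first comma-separated field is kept as the model name
    # (see the TODO above about better handling of model_args).
    return model_args.split(",")[0].strip()


print(resolve_model_args("examples/model_configs/litellm_model.yaml"))
print(resolve_model_args("gpt-3.5-turbo"))  # the "Quick use" form still works
```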

src/lighteval/models/litellm_model.py

Lines changed: 40 additions & 10 deletions
@@ -21,17 +21,18 @@
 # SOFTWARE.
 
 import logging
-import os
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from typing import Optional
 
+import yaml
 from tqdm import tqdm
 
 from lighteval.data import GenerativeTaskDataset
 from lighteval.models.abstract_model import LightevalModel
 from lighteval.models.endpoints.endpoint_model import ModelInfo
+from lighteval.models.model_input import GenerationParameters
 from lighteval.models.model_output import (
     GenerativeResponse,
     LoglikelihoodResponse,
@@ -63,6 +64,32 @@
 @dataclass
 class LiteLLMModelConfig:
     model: str
+    provider: Optional[str] = None
+    base_url: Optional[str] = None
+    api_key: Optional[str] = None
+    generation_parameters: GenerationParameters = None
+
+    def __post_init__(self):
+        if self.generation_parameters is None:
+            self.generation_parameters = GenerationParameters()
+
+    @classmethod
+    def from_path(cls, path):
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+
+        model = config["base_params"]["model_name"]
+        provider = config["base_params"].get("provider", None)
+        base_url = config["base_params"].get("base_url", None)
+        api_key = config["base_params"].get("api_key", None)
+        generation_parameters = GenerationParameters.from_dict(config)
+        return cls(
+            model=model,
+            provider=provider,
+            base_url=base_url,
+            generation_parameters=generation_parameters,
+            api_key=api_key,
+        )
 
 
 class LiteLLMClient(LightevalModel):
@@ -79,15 +106,17 @@ def __init__(self, config, env_config) -> None:
             model_dtype=None,
             model_size="",
         )
-        self.provider = config.model.split("/")[0]
-        self.base_url = os.getenv(f"{self.provider.upper()}_BASE_URL", None)
+        self.model = config.model
+        self.provider = config.provider or config.model.split("/")[0]
+        self.base_url = config.base_url
+        self.api_key = config.api_key
+        self.generation_parameters = config.generation_parameters
+
         self.API_MAX_RETRY = 5
         self.API_RETRY_SLEEP = 3
         self.API_RETRY_MULTIPLIER = 2
         self.CONCURENT_CALLS = 20  # 100 leads to hitting Anthropic rate limits
-        self.TEMPERATURE = 0.3
-        self.TOP_P = 0.95
-        self.model = config.model
+
         self._tokenizer = encode
         self.pairwise_tokenization = False
         litellm.drop_params = True
@@ -126,18 +155,19 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
         kwargs = {
             "model": self.model,
             "messages": prompt,
-            "max_completion_tokens": max_new_tokens,
             "logprobs": return_logits if self.provider == "openai" else None,
             "base_url": self.base_url,
             "n": num_samples,
             "caching": True,
+            "api_key": self.api_key,
         }
         if "o1" in self.model:
             logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
         else:
-            kwargs["temperature"] = self.TEMPERATURE
-            kwargs["top_p"] = self.TOP_P
-            kwargs["stop"] = stop_sequence
+            kwargs.update(self.generation_parameters.to_litellm_dict())
+
+        if kwargs.get("max_completion_tokens", None) is None:
+            kwargs["max_completion_tokens"] = max_new_tokens
 
         response = litellm.completion(**kwargs)
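
A short usage sketch (not part of the commit) of the new `from_path` helper, assuming the example config file added above:

```python
# Sketch: load the example YAML added in this commit and inspect the parsed
# fields; expected values mirror examples/model_configs/litellm_model.yaml.
from lighteval.models.litellm_model import LiteLLMModelConfig

config = LiteLLMModelConfig.from_path("examples/model_configs/litellm_model.yaml")
print(config.model)     # "openai/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
print(config.base_url)  # "https://router.huggingface.co/hf-inference/v1"
print(config.provider)  # None -> the client falls back to model.split("/")[0], i.e. "openai"
print(config.generation_parameters.temperature)  # 0.5
```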

src/lighteval/models/model_input.py

Lines changed: 23 additions & 5 deletions
@@ -32,15 +32,15 @@ class GenerationParameters:
     length_penalty: Optional[float] = None  # vllm, transformers
     presence_penalty: Optional[float] = None  # vllm
 
-    max_new_tokens: Optional[int] = None  # vllm, transformers, tgi
+    max_new_tokens: Optional[int] = None  # vllm, transformers, tgi, litellm
     min_new_tokens: Optional[int] = None  # vllm, transformers
 
-    seed: Optional[int] = None  # vllm, tgi
-    stop_tokens: Optional[list[str]] = None  # vllm, transformers, tgi
-    temperature: Optional[float] = None  # vllm, transformers, tgi
+    seed: Optional[int] = None  # vllm, tgi, litellm
+    stop_tokens: Optional[list[str]] = None  # vllm, transformers, tgi, litellm
+    temperature: Optional[float] = None  # vllm, transformers, tgi, litellm
     top_k: Optional[int] = None  # vllm, transformers, tgi
     min_p: Optional[float] = None  # vllm, transformers
-    top_p: Optional[int] = None  # vllm, transformers, tgi
+    top_p: Optional[int] = None  # vllm, transformers, tgi, litellm
     truncate_prompt: Optional[bool] = None  # vllm, tgi
 
     @classmethod
@@ -59,6 +59,24 @@ def from_dict(cls, config_dict: dict):
         """
         return GenerationParameters(**config_dict.get("generation", {}))
 
+    def to_litellm_dict(self) -> dict:
+        """Selects relevant generation and sampling parameters for litellm models.
+        Doc: https://docs.litellm.ai/docs/completion/input#input-params-1
+
+        Returns:
+            dict: The parameters to create a litellm.SamplingParams in the model config.
+        """
+        args = {
+            "max_completion_tokens": self.max_new_tokens,
+            "stop": self.stop_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "seed": self.seed,
+            "repetition_penalty": self.repetition_penalty,
+            "frequency_penalty": self.frequency_penalty,
+        }
+        return {k: v for k, v in args.items() if v is not None}
+
     def to_vllm_dict(self) -> dict:
         """Selects relevant generation and sampling parameters for vllm models.
         Doc: https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html
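
For reference, a small sketch (not part of the commit) showing how `to_litellm_dict` renames parameters and drops unset values:

```python
# Sketch: to_litellm_dict maps max_new_tokens/stop_tokens to litellm's argument
# names and filters out any parameter left at None.
from lighteval.models.model_input import GenerationParameters

params = GenerationParameters(temperature=0.5, max_new_tokens=256, top_p=0.9)
print(params.to_litellm_dict())
# {'max_completion_tokens': 256, 'temperature': 0.5, 'top_p': 0.9}
```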

0 commit comments

Comments
 (0)