Skip to content

Commit 5120c58

Browse files
authored
change inference provider model config to inherit from ModelConfig base class (#678)
1 parent 8e84e1b commit 5120c58

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/lighteval/models/endpoints/inference_providers_model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,11 @@
2222

2323
import asyncio
2424
import logging
25-
from dataclasses import field
2625
from typing import Any, List, Optional
2726

2827
import yaml
2928
from huggingface_hub import AsyncInferenceClient, ChatCompletionOutput
30-
from pydantic import BaseModel, NonNegativeInt
29+
from pydantic import NonNegativeInt
3130
from tqdm import tqdm
3231
from tqdm.asyncio import tqdm as async_tqdm
3332
from transformers import AutoTokenizer
@@ -41,6 +40,7 @@
4140
LoglikelihoodResponse,
4241
LoglikelihoodSingleTokenResponse,
4342
)
43+
from lighteval.models.utils import ModelConfig
4444
from lighteval.tasks.requests import (
4545
GreedyUntilRequest,
4646
LoglikelihoodRequest,
@@ -52,7 +52,7 @@
5252
logger = logging.getLogger(__name__)
5353

5454

55-
class InferenceProvidersModelConfig(BaseModel):
55+
class InferenceProvidersModelConfig(ModelConfig):
5656
"""Configuration for InferenceProvidersClient.
5757
5858
Args:
@@ -68,7 +68,6 @@ class InferenceProvidersModelConfig(BaseModel):
6868
timeout: int | None = None
6969
proxies: Any | None = None
7070
parallel_calls_count: NonNegativeInt = 10
71-
generation_parameters: GenerationParameters = field(default_factory=GenerationParameters)
7271

7372
@classmethod
7473
def from_path(cls, path):

0 commit comments

Comments
 (0)