Skip to content

Commit 41e56ec

Browse files
committed
add extra_body to ModelProvider
1 parent fdbc012 commit 41e56ec

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

src/data_designer/config/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ class ModelProvider(ConfigBase):
221221
name: str
222222
endpoint: str
223223
provider_type: str = "openai"
224-
api_key: str | None = None
224+
api_key: Optional[str] = None
225+
extra_body: Optional[dict[str, Any]] = None
225226

226227

227228
def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]:

src/data_designer/engine/models/facade.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
1212
from litellm.types.utils import ModelResponse
1313

14-
from data_designer.config.models import ModelConfig
14+
from data_designer.config.models import ModelConfig, ModelProvider
1515
from data_designer.engine.model_provider import ModelProviderRegistry
1616
from data_designer.engine.models.errors import (
1717
GenerationValidationFailureError,
@@ -45,9 +45,13 @@ def __init__(
4545
def model_name(self) -> str:
4646
return self._model_config.model
4747

48+
@property
49+
def model_provider(self) -> ModelProvider:
50+
return self._model_provider_registry.get_provider(self._model_config.provider)
51+
4852
@property
4953
def model_provider_name(self) -> str:
50-
return self._model_provider_registry.get_provider(self._model_config.provider).name
54+
return self.model_provider.name
5155

5256
@property
5357
def model_alias(self) -> str:
@@ -63,6 +67,8 @@ def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool =
6367
extra={"model": self.model_name, "messages": messages, "sensitive": True},
6468
)
6569
response = None
70+
if self.model_provider.extra_body:
71+
kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
6672
try:
6773
response = self._router.completion(self.model_name, messages, **kwargs)
6874
logger.debug(

tests/engine/models/test_facade.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,27 @@ def mock_completion(model_name, messages, **kwargs):
148148

149149
assert result == stub_expected_response
150150
assert captured_kwargs == kwargs
151+
152+
153+
@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
154+
def test_completion_with_extra_body(mock_router_completion, stub_model_facade):
155+
messages = [{"role": "user", "content": "test"}]
156+
157+
# completion call has no extra body argument and provider has no extra body
158+
_ = stub_model_facade.completion(messages)
159+
assert len(mock_router_completion.call_args) == 2
160+
assert mock_router_completion.call_args[0][1] == "stub-model-text"
161+
assert mock_router_completion.call_args[0][2] == messages
162+
163+
# completion call has no extra body argument and provider has extra body.
164+
# Should pull extra body from model provider
165+
custom_extra_body = {"some_custom_key": "some_custom_value"}
166+
stub_model_facade.model_provider.extra_body = custom_extra_body
167+
_ = stub_model_facade.completion(messages)
168+
assert mock_router_completion.call_args[1] == {"extra_body": custom_extra_body}
169+
170+
# completion call has extra body argument and provider has extra body.
171+
# Should merge the two with provider extra body taking precedence
172+
completion_extra_body = {"some_completion_key": "some_completion_value", "some_custom_key": "some_different_value"}
173+
_ = stub_model_facade.completion(messages, extra_body=completion_extra_body)
174+
assert mock_router_completion.call_args[1] == {"extra_body": {**completion_extra_body, **custom_extra_body}}

0 commit comments

Comments
 (0)