24 changes: 24 additions & 0 deletions tests/unit/types/test_chat_completion_validation.py
@@ -0,0 +1,24 @@
import pytest

from tlm.utils.chat_completion_validation import _validate_chat_completion_params


def test_validate_chat_completion_params_allows_valid_openai_keys() -> None:
    params = {"messages": [], "model": "gpt-4.1", "temperature": 0.5}

    _validate_chat_completion_params(params)


def test_validate_chat_completion_params_allows_provider_as_none() -> None:
    params = {"messages": [], "model": "gpt-4.1", "temperature": 0.5, "provider": None}

    _validate_chat_completion_params(params)


def test_validate_chat_completion_params_requires_messages() -> None:
    params = {"model": "gpt-4.1-mini"}

    with pytest.raises(ValueError) as exc_info:
        _validate_chat_completion_params(params)

    assert "openai_args must include the following parameter(s): messages" in str(exc_info.value)
2 changes: 2 additions & 0 deletions tlm/api.py
@@ -4,6 +4,7 @@
from tlm.config.presets import WorkflowType
from tlm.inference import InferenceResult, tlm_inference
from tlm.types import SemanticEval, CompletionParams
from tlm.utils.chat_completion_validation import _validate_chat_completion_params


async def inference(
@@ -14,6 +15,7 @@ async def inference(
    evals: list[SemanticEval] | None = None,
    config_input: ConfigInput = ConfigInput(),
) -> InferenceResult:
    _validate_chat_completion_params(openai_args)
    workflow_type = WorkflowType.from_inference_params(
        openai_args=openai_args,
        score=response is not None,
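With this change, a call that omits `messages` now fails fast, before any workflow setup. A minimal sketch of the behavior, assuming `inference` accepts `openai_args` as a keyword argument (its leading parameters are elided in this diff):

import asyncio

from tlm.api import inference


async def main() -> None:
    try:
        # "messages" is absent, so validation raises before any model call.
        await inference(openai_args={"model": "gpt-4.1-mini"})
    except ValueError as exc:
        print(exc)  # openai_args must include the following parameter(s): messages


asyncio.run(main())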
21 changes: 21 additions & 0 deletions tlm/utils/chat_completion_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Validation helpers for chat completion parameter dictionaries."""

from __future__ import annotations

from typing import FrozenSet
# from litellm import get_supported_openai_params
from tlm.types.base import CompletionParams


REQUIRED_CHAT_COMPLETION_PARAMS: FrozenSet[str] = frozenset({"messages"})


def _validate_chat_completion_params(params: CompletionParams) -> None:
Collaborator: should we also check that messages is non-empty? (i'm not sure if a chat completion can be generated without any messages)

Collaborator: and can we also validate that the message items are properly formatted, e.g. include role and content?

https://docs.litellm.ai/docs/completion/input#required-fields
"""Ensure only supported chat completion params are passed into inference."""

missing_required = [param for param in REQUIRED_CHAT_COMPLETION_PARAMS if param not in params]
if missing_required:
required_str = ", ".join(sorted(REQUIRED_CHAT_COMPLETION_PARAMS))
raise ValueError(f"openai_args must include the following parameter(s): {required_str}")

return
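A hedged sketch of the follow-up suggested in the review comments above, as it might be added to this module (not part of this PR: `_validate_messages`, `REQUIRED_MESSAGE_FIELDS`, and the error wording are hypothetical; this assumes dict-like access on `CompletionParams`, and some message shapes, e.g. assistant tool calls, may legitimately omit `content`):

REQUIRED_MESSAGE_FIELDS: FrozenSet[str] = frozenset({"role", "content"})


def _validate_messages(params: CompletionParams) -> None:
    """Hypothetical: reject empty message lists and items missing role/content.

    Mirrors the required fields listed at
    https://docs.litellm.ai/docs/completion/input#required-fields.
    """
    messages = params.get("messages") or []
    if not messages:
        raise ValueError("openai_args['messages'] must contain at least one message")
    for i, message in enumerate(messages):
        # Iterating a message dict yields its keys; anything required but absent is reported.
        missing = REQUIRED_MESSAGE_FIELDS - set(message)
        if missing:
            missing_str = ", ".join(sorted(missing))
            raise ValueError(f"openai_args['messages'][{i}] is missing field(s): {missing_str}")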