Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions tests/unit/types/test_chat_completion_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest

from tlm.utils.chat_completion_validation import _validate_chat_completion_params


def test_validate_chat_completion_params_allows_valid_openai_keys() -> None:
    """A params dict built from standard OpenAI keys validates without raising."""
    valid_params = {"messages": [], "model": "gpt-4.1", "temperature": 0.5}
    _validate_chat_completion_params(valid_params)


def test_validate_chat_completion_params_allows_provider_as_none() -> None:
    """An explicit ``provider`` key set to None must not trip validation.

    The previous version of this test omitted the ``provider`` key entirely,
    making it an exact duplicate of the valid-keys test; here the key is
    actually present so the None-provider path is exercised.
    """
    params = {
        "messages": [],
        "model": "gpt-4.1",
        "temperature": 0.5,
        "provider": None,
    }

    _validate_chat_completion_params(params)


def test_validate_chat_completion_params_requires_messages() -> None:
    """Omitting `messages` entirely raises a ValueError naming the missing param."""
    with pytest.raises(ValueError) as exc_info:
        _validate_chat_completion_params({"model": "gpt-4.1-mini"})

    error_text = str(exc_info.value)
    assert "openai_args must include the following parameter(s): messages" in error_text


def test_validate_chat_completion_params_requires_messages_list() -> None:
    """A non-list `messages` value is rejected with a descriptive error."""
    with pytest.raises(ValueError) as exc_info:
        _validate_chat_completion_params({"messages": "not-a-list"})

    error_text = str(exc_info.value)
    assert "`messages` must be provided as a list" in error_text


def test_validate_chat_completion_params_requires_message_dict() -> None:
    """Each entry of `messages` must be a dict; the error names the bad index."""
    with pytest.raises(ValueError) as exc_info:
        _validate_chat_completion_params({"messages": ["not-a-dict"]})

    error_text = str(exc_info.value)
    assert "messages[0] must be a dictionary" in error_text


def test_validate_chat_completion_params_requires_role_and_content_strings() -> None:
    """A non-string role produces an error that names the offending field."""
    bad_message = {"role": 123, "content": None}

    with pytest.raises(ValueError) as exc_info:
        _validate_chat_completion_params({"messages": [bad_message]})

    assert "messages[0]['role']" in str(exc_info.value)


def test_validate_chat_completion_params_allows_function_call_without_content() -> None:
    """An assistant message may omit content when it carries a function call."""
    function_call_message = {
        "role": "assistant",
        "content": None,
        "function_call": {"name": "foo", "arguments": '{"bar": 1}'},
    }

    _validate_chat_completion_params({"messages": [function_call_message]})


def test_validate_chat_completion_params_requires_content_when_no_function_call() -> None:
    """An assistant message lacking a function call must have string content."""
    contentless_message = {"role": "assistant", "content": None}

    with pytest.raises(ValueError) as exc_info:
        _validate_chat_completion_params({"messages": [contentless_message]})

    assert "messages[0]['content'] must be a string." in str(exc_info.value)
2 changes: 2 additions & 0 deletions tlm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tlm.config.presets import WorkflowType
from tlm.inference import InferenceResult, tlm_inference
from tlm.types import SemanticEval, CompletionParams
from tlm.utils.chat_completion_validation import _validate_chat_completion_params


async def inference(
Expand All @@ -14,6 +15,7 @@ async def inference(
evals: list[SemanticEval] | None = None,
config_input: ConfigInput = ConfigInput(),
) -> InferenceResult:
_validate_chat_completion_params(openai_args)
workflow_type = WorkflowType.from_inference_params(
openai_args=openai_args,
score=response is not None,
Expand Down
54 changes: 54 additions & 0 deletions tlm/utils/chat_completion_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Validation helpers for chat completion parameter dictionaries."""

from typing import Any, Callable, FrozenSet, Mapping

from tlm.types.base import CompletionParams


# Signature shared by all per-param validators: receive the raw param value,
# raise ValueError on failure, return None on success.
ParamValidator = Callable[[Any], None]

# Params that must be present in every chat-completion `openai_args` dict.
REQUIRED_CHAT_COMPLETION_PARAMS: FrozenSet[str] = frozenset({"messages"})


def _validate_messages_param(messages: Any) -> None:
"""Validate the shape of a `messages` param for chat completions."""

if not isinstance(messages, list):
raise ValueError("`messages` must be provided as a list of message dictionaries.")

for index, message in enumerate(messages):
if not isinstance(message, dict):
raise ValueError(f"messages[{index}] must be a dictionary.")

role = message.get("role")
content = message.get("content")

if role is None or not isinstance(role, str):
raise ValueError(f"messages[{index}]['role'] must be a non-empty string.")

if content is None or not isinstance(content, str):
function_call = message.get("function_call")
if role != "assistant":
raise ValueError(f"Non-assistant message at index {index} must have content.")
if function_call is None:
raise ValueError(f"Assistant message at index {index} must have content or a function call.")


# Maps each required param name to the validator that checks its shape.
# Required params without an entry here are presence-checked only.
REQUIRED_PARAM_VALIDATORS: Mapping[str, ParamValidator] = {
    "messages": _validate_messages_param,
}


def _validate_chat_completion_params(params: CompletionParams) -> None:  # type: ignore
    """Check that required chat completion params are present and well-formed.

    Note: this does NOT reject unrecognized keys (the old docstring claimed
    "only supported params" were enforced, which was misleading). It verifies
    that every param in REQUIRED_CHAT_COMPLETION_PARAMS is present, then runs
    the validator registered for it, if any.

    Args:
        params: The `openai_args` mapping forwarded to inference.

    Raises:
        ValueError: If a required param is missing or fails its validator.
    """

    missing_required = [param for param in REQUIRED_CHAT_COMPLETION_PARAMS if param not in params]
    if missing_required:
        required_str = ", ".join(sorted(REQUIRED_CHAT_COMPLETION_PARAMS))
        raise ValueError(f"openai_args must include the following parameter(s): {required_str}")

    for param in REQUIRED_CHAT_COMPLETION_PARAMS:
        # Required params may be presence-only (no registered validator).
        validator = REQUIRED_PARAM_VALIDATORS.get(param)
        if validator is not None:
            validator(params[param])
Loading