Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions tests/unit/types/test_chat_completion_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest

from tlm.utils.chat_completion_validation import _validate_chat_completion_params


def test_validate_chat_completion_params_allows_valid_openai_keys() -> None:
    # A well-formed payload using standard OpenAI keys should validate silently.
    _validate_chat_completion_params(
        {"messages": [], "model": "gpt-4.1", "temperature": 0.5}
    )


def test_validate_chat_completion_params_allows_provider_as_none() -> None:
    # Regression fix: the payload must actually carry provider=None — without it
    # this test was a byte-for-byte duplicate of the valid-keys test and
    # exercised nothing new.
    params = {
        "messages": [],
        "model": "gpt-4.1",
        "temperature": 0.5,
        "provider": None,
    }

    _validate_chat_completion_params(params)


def test_validate_chat_completion_params_requires_messages() -> None:
    # A payload without `messages` must be rejected with the required-params error.
    with pytest.raises(ValueError, match=r"must include the following parameter\(s\): messages"):
        _validate_chat_completion_params({"model": "gpt-4.1-mini"})


def test_validate_chat_completion_params_requires_messages_list() -> None:
    # `messages` given as a plain string (not a list) must be rejected.
    with pytest.raises(ValueError, match="`messages` must be provided as a list"):
        _validate_chat_completion_params({"messages": "not-a-list"})


def test_validate_chat_completion_params_requires_message_dict() -> None:
    # Every entry in `messages` must be a dict; a bare string is rejected.
    with pytest.raises(ValueError, match=r"messages\[0\] must be a dictionary"):
        _validate_chat_completion_params({"messages": ["not-a-dict"]})


def test_validate_chat_completion_params_requires_role_and_content_strings() -> None:
    # A non-string role fails first, before content is even considered.
    with pytest.raises(ValueError, match=r"messages\[0\]\['role'\]"):
        _validate_chat_completion_params({"messages": [{"role": 123, "content": None}]})


def test_validate_chat_completion_params_allows_function_call_without_content() -> None:
    # Assistant messages may carry content=None when a function_call is present.
    assistant_message = {
        "role": "assistant",
        "content": None,
        "function_call": {"name": "foo", "arguments": '{"bar": 1}'},
    }

    _validate_chat_completion_params({"messages": [assistant_message]})


def test_validate_chat_completion_params_requires_content_when_no_function_call() -> None:
    # An assistant message with content=None and no function_call is invalid.
    with pytest.raises(ValueError, match=r"messages\[0\]\['content'\] must be a string\."):
        _validate_chat_completion_params(
            {"messages": [{"role": "assistant", "content": None}]}
        )
2 changes: 2 additions & 0 deletions tlm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tlm.config.presets import WorkflowType
from tlm.inference import InferenceResult, tlm_inference
from tlm.types import SemanticEval, CompletionParams
from tlm.utils.chat_completion_validation import _validate_chat_completion_params


async def inference(
Expand All @@ -14,6 +15,7 @@ async def inference(
evals: list[SemanticEval] | None = None,
config_input: ConfigInput = ConfigInput(),
) -> InferenceResult:
_validate_chat_completion_params(openai_args)
workflow_type = WorkflowType.from_inference_params(
openai_args=openai_args,
score=response is not None,
Expand Down
53 changes: 53 additions & 0 deletions tlm/utils/chat_completion_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Validation helpers for chat completion parameter dictionaries."""

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete unused comment?

from typing import Any, Callable, FrozenSet, Mapping

from tlm.types.base import CompletionParams


# Signature shared by all per-parameter validators: takes the raw value and
# raises ValueError on a malformed shape; returns None on success.
ParamValidator = Callable[[Any], None]

# Parameters that must be present in every chat completion payload.
REQUIRED_CHAT_COMPLETION_PARAMS: FrozenSet[str] = frozenset({"messages"})


def _validate_messages_param(messages: Any) -> None:
"""Validate the shape of a `messages` param for chat completions."""

if not isinstance(messages, list):
raise ValueError("`messages` must be provided as a list of message dictionaries.")

for index, message in enumerate(messages):
if not isinstance(message, dict):
raise ValueError(f"messages[{index}] must be a dictionary with 'role' and 'content'.")

role = message.get("role")
content = message.get("content")

if role is None or not isinstance(role, str):
raise ValueError(f"messages[{index}]['role'] must be a non-empty string.")

if content is None or not isinstance(content, str):
function_call = message.get("function_call")
if role != "assistant" or function_call is None:
raise ValueError(f"messages[{index}]['content'] must be a string.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this error message be more specific/helpful? Not sure the best phrasing but something like this. Also I can't tell but it seems like there might be a logical error in the condition - is it never okay for function_call to be None if the content is None? Would be helpful to explain that in the error msg, probably by splitting the conditions

Suggested change
if role != "assistant" or function_call is None:
raise ValueError(f"messages[{index}]['content'] must be a string.")
if role != "assistant":
raise ValueError(f"Non-assistant message at index {index} must have content.")
if function_call is None:
raise ValueError(f"messages[{index}] must have either content or function_call key.")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this logic after referring to litellm's required field explanation: "content: string or list[dict] or null - The contents of the message. It is required for all messages, but may be null for assistant messages with function calls." I'll split the if logic up with more fitting error messages.



# Registry mapping each required parameter to its shape validator; required
# params without an entry here are only checked for presence.
REQUIRED_PARAM_VALIDATORS: Mapping[str, ParamValidator] = {
    "messages": _validate_messages_param,
}


def _validate_chat_completion_params(params: CompletionParams) -> None:  # type: ignore
    """Ensure only supported chat completion params are passed into inference.

    Raises:
        ValueError: If a required parameter is missing or fails its validator.
    """

    # Presence check for every required parameter before any shape validation.
    if not REQUIRED_CHAT_COMPLETION_PARAMS.issubset(params):
        names = ", ".join(sorted(REQUIRED_CHAT_COMPLETION_PARAMS))
        raise ValueError(f"openai_args must include the following parameter(s): {names}")

    # Run the registered shape validator for each required parameter, if any.
    for name in REQUIRED_CHAT_COMPLETION_PARAMS:
        check = REQUIRED_PARAM_VALIDATORS.get(name)
        if check is not None:
            check(params[name])
Loading