Skip to content

Commit b58932a

Browse files
committed
fix: Validate the completion params that are passed as input to the API #3
- added hardcoded validation for openai_args in inference. reference: openai https://platform.openai.com/docs/api-reference/chat/create anthropic https://platform.claude.com/docs/en/api/python/messages/create microsoft https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/chatgpt?view=foundry-classic&tabs=python-secure
1 parent 5bbd4be commit b58932a

File tree

3 files changed

+110
-4
lines changed

3 files changed

+110
-4
lines changed

tlm/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from tlm.config.presets import WorkflowType
55
from tlm.inference import InferenceResult, tlm_inference
66
from tlm.types import SemanticEval, CompletionParams
7+
from tlm.utils.chat_completion_validation import _validate_chat_completion_params
78

89

910
async def inference(
@@ -21,6 +22,7 @@ async def inference(
2122
constrain_outputs=config_input.constrain_outputs,
2223
)
2324
config = Config.from_input(config_input, workflow_type)
25+
_validate_chat_completion_params(openai_args, config.provider)
2426
return await tlm_inference(
2527
completion_params=openai_args,
2628
response=response,

tlm/config/provider.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ class APICredentials(BaseModel):
1414
api_base: str | None = None
1515
api_version: str | None = None
1616

17+
OPENAI_PROVIDER = "openai"
18+
BEDROCK_PROVIDER = "bedrock"
19+
GOOGLE_PROVIDER = "google"
20+
AZURE_PROVIDER = "azure"
21+
1722

1823
class ModelProvider(APICredentials):
1924
model: str
@@ -24,13 +29,13 @@ def set_provider_from_model(self):
2429
"""Automatically set provider based on model name if provider is None."""
2530
if self.provider is None:
2631
if self.model in OPENAI_MODELS:
27-
self.provider = "openai"
32+
self.provider = OPENAI_PROVIDER
2833
elif self.model in BEDROCK_MODELS:
29-
self.provider = "bedrock"
34+
self.provider = BEDROCK_PROVIDER
3035
elif self.model in GOOGLE_MODELS:
31-
self.provider = "google"
36+
self.provider = GOOGLE_PROVIDER
3237
elif self.model in AZURE_MODELS:
33-
self.provider = "azure"
38+
self.provider = AZURE_PROVIDER
3439

3540
if self.model in BEDROCK_MODELS:
3641
self.model = BEDROCK_MODEL_TO_INFERENCE_PROFILE_ID[self.model]
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Validation helpers for chat completion parameter dictionaries."""
2+
3+
from __future__ import annotations
4+
from tlm.types.base import CompletionParams
5+
from tlm.config.provider import AZURE_PROVIDER, BEDROCK_PROVIDER, GOOGLE_PROVIDER, OPENAI_PROVIDER
6+
from typing import FrozenSet
7+
8+
VALID_OPENAI_CHAT_COMPLETION_PARAMS: FrozenSet[str] = frozenset(
9+
{
10+
"audio",
11+
"function_call",
12+
"functions",
13+
"frequency_penalty",
14+
"logit_bias",
15+
"logprobs",
16+
"max_completion_tokens",
17+
"max_tokens",
18+
"messages",
19+
"metadata",
20+
"model",
21+
"modalities",
22+
"n",
23+
"parallel_tool_calls",
24+
"prediction",
25+
"presence_penalty",
26+
"prompt_cache_key",
27+
"prompt_cache_retention",
28+
"reasoning",
29+
"reasoning_effort",
30+
"response_format",
31+
"safety_identifier",
32+
"seed",
33+
"service_tier",
34+
"stop",
35+
"store",
36+
"stream",
37+
"stream_options",
38+
"temperature",
39+
"tool_choice",
40+
"tools",
41+
"top_logprobs",
42+
"top_p",
43+
"user",
44+
"verbosity",
45+
"web_search_options",
46+
}
47+
)
48+
VALID_AZURE_CHAT_COMPLETION_PARAMS: FrozenSet[str] = VALID_OPENAI_CHAT_COMPLETION_PARAMS
49+
VALID_GOOGLE_CHAT_COMPLETION_PARAMS: FrozenSet[str] = VALID_OPENAI_CHAT_COMPLETION_PARAMS
50+
VALID_BEDROCK_CHAT_COMPLETION_PARAMS: FrozenSet[str] = frozenset(
51+
{
52+
*VALID_OPENAI_CHAT_COMPLETION_PARAMS,
53+
"betas",
54+
"container",
55+
"context_management",
56+
"mcp_servers",
57+
"output_config",
58+
"output_format",
59+
"stop_sequences",
60+
"system",
61+
"thinking",
62+
"top_k",
63+
"top_p",
64+
"tool_config",
65+
}
66+
)
67+
68+
def _resolve_valid_chat_completion_params(provider: str | None) -> FrozenSet[str]:
69+
if provider == OPENAI_PROVIDER:
70+
return VALID_OPENAI_CHAT_COMPLETION_PARAMS
71+
elif provider == BEDROCK_PROVIDER:
72+
return VALID_BEDROCK_CHAT_COMPLETION_PARAMS
73+
elif provider == GOOGLE_PROVIDER:
74+
return VALID_GOOGLE_CHAT_COMPLETION_PARAMS
75+
elif provider == AZURE_PROVIDER:
76+
return VALID_AZURE_CHAT_COMPLETION_PARAMS
77+
else:
78+
return VALID_OPENAI_CHAT_COMPLETION_PARAMS
79+
80+
81+
REQUIRED_CHAT_COMPLETION_PARAMS: FrozenSet[str] = frozenset({"messages"})
82+
83+
84+
def _validate_chat_completion_params(params: CompletionParams, provider: str | None) -> None:
85+
"""Ensure only supported chat completion params are passed into inference."""
86+
87+
missing_required = [param for param in REQUIRED_CHAT_COMPLETION_PARAMS if param not in params]
88+
if missing_required:
89+
required_str = ", ".join(sorted(REQUIRED_CHAT_COMPLETION_PARAMS))
90+
raise ValueError(f"openai_args must include the following parameter(s): {required_str}")
91+
92+
valid_params = _resolve_valid_chat_completion_params(provider)
93+
94+
invalid_keys = sorted(set(params.keys()) - valid_params)
95+
if invalid_keys:
96+
raise ValueError(
97+
f"Unsupported chat completion parameter(s) for provider {provider}: "
98+
+ ", ".join(invalid_keys)
99+
)

0 commit comments

Comments
 (0)