Skip to content

Commit 4adfd18

Browse files
[Feat]Add support for safety_identifier parameter in chat.completions.create (#14174)
* Add support for safety_identifier parameter in chat.completions.create * Ensure the param is actually passed through to the raw API request
1 parent 61b2209 commit 4adfd18

File tree

7 files changed

+76
-6
lines changed

7 files changed

+76
-6
lines changed

docs/my-website/docs/completion/input.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def completion(
106106
parallel_tool_calls: Optional[bool] = None,
107107
logprobs: Optional[bool] = None,
108108
top_logprobs: Optional[int] = None,
109+
safety_identifier: Optional[str] = None,
109110
deployment_id=None,
110111
# soon to be deprecated params by OpenAI
111112
functions: Optional[List] = None,
@@ -196,6 +197,8 @@ def completion(
196197

197198
- `top_logprobs`: *int (optional)* - An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to true if this parameter is used.
198199

200+
- `safety_identifier`: *string (optional)* - A stable identifier for your end users, used by the provider to detect and monitor users who may be violating usage policies. Passed through to the underlying API (e.g. OpenAI) for safety monitoring.
201+
199202
- `headers`: *dict (optional)* - A dictionary of headers to be sent with the request.
200203

201204
- `extra_headers`: *dict (optional)* - Alternative to `headers`, used to send extra headers in LLM API request.

litellm/constants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
DEFAULT_SQS_FLUSH_INTERVAL_SECONDS = int(
1515
os.getenv("DEFAULT_SQS_FLUSH_INTERVAL_SECONDS", 10)
1616
)
17-
DEFAULT_NUM_WORKERS_LITELLM_PROXY = int(os.getenv("DEFAULT_NUM_WORKERS_LITELLM_PROXY", 4))
17+
DEFAULT_NUM_WORKERS_LITELLM_PROXY = int(
18+
os.getenv("DEFAULT_NUM_WORKERS_LITELLM_PROXY", 4)
19+
)
1820
DEFAULT_SQS_BATCH_SIZE = int(os.getenv("DEFAULT_SQS_BATCH_SIZE", 512))
1921
SQS_SEND_MESSAGE_ACTION = "SendMessage"
2022
SQS_API_VERSION = "2012-11-05"
@@ -395,6 +397,7 @@
395397
"reasoning_effort": None,
396398
"thinking": None,
397399
"web_search_options": None,
400+
"safety_identifier": None,
398401
}
399402

400403
openai_compatible_endpoints: List = [

litellm/llms/openai/chat/gpt_transformation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def get_supported_openai_params(self, model: str) -> list:
158158
"parallel_tool_calls",
159159
"audio",
160160
"web_search_options",
161+
"safety_identifier",
161162
] # works across all models
162163

163164
model_specific_params = []

litellm/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ async def acompletion(
357357
top_logprobs: Optional[int] = None,
358358
deployment_id=None,
359359
reasoning_effort: Optional[Literal["minimal", "low", "medium", "high"]] = None,
360+
safety_identifier: Optional[str] = None,
360361
# set api_base, api_version, api_key
361362
base_url: Optional[str] = None,
362363
api_version: Optional[str] = None,
@@ -493,6 +494,7 @@ async def acompletion(
493494
"api_key": api_key,
494495
"model_list": model_list,
495496
"reasoning_effort": reasoning_effort,
497+
"safety_identifier": safety_identifier,
496498
"extra_headers": extra_headers,
497499
"acompletion": True, # assuming this is a required parameter
498500
"thinking": thinking,
@@ -906,6 +908,7 @@ def completion( # type: ignore # noqa: PLR0915
906908
web_search_options: Optional[OpenAIWebSearchOptions] = None,
907909
deployment_id=None,
908910
extra_headers: Optional[dict] = None,
911+
safety_identifier: Optional[str] = None,
909912
# soon to be deprecated params by OpenAI
910913
functions: Optional[List] = None,
911914
function_call: Optional[str] = None,
@@ -1243,6 +1246,7 @@ def completion( # type: ignore # noqa: PLR0915
12431246
"reasoning_effort": reasoning_effort,
12441247
"thinking": thinking,
12451248
"web_search_options": web_search_options,
1249+
"safety_identifier": safety_identifier,
12461250
"allowed_openai_params": kwargs.get("allowed_openai_params"),
12471251
}
12481252
optional_params = get_optional_params(

litellm/types/llms/openai.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,7 @@ class ChatCompletionRequest(TypedDict, total=False):
788788
response_format: dict
789789
seed: int
790790
service_tier: str
791+
safety_identifier: str
791792
stop: Union[str, List[str]]
792793
stream_options: dict
793794
temperature: float

litellm/utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -837,15 +837,13 @@ async def _client_async_logging_helper(
837837
# Async Logging Worker
838838
################################################
839839
from litellm.litellm_core_utils.logging_worker import GLOBAL_LOGGING_WORKER
840+
840841
GLOBAL_LOGGING_WORKER.ensure_initialized_and_enqueue(
841-
async_coroutine = logging_obj.async_success_handler(
842-
result=result,
843-
start_time=start_time,
844-
end_time=end_time
842+
async_coroutine=logging_obj.async_success_handler(
843+
result=result, start_time=start_time, end_time=end_time
845844
)
846845
)
847846

848-
849847
################################################
850848
# Sync Logging Worker
851849
################################################
@@ -3304,6 +3302,7 @@ def get_optional_params( # noqa: PLR0915
33043302
messages: Optional[List[AllMessageValues]] = None,
33053303
thinking: Optional[AnthropicThinkingParam] = None,
33063304
web_search_options: Optional[OpenAIWebSearchOptions] = None,
3305+
safety_identifier: Optional[str] = None,
33073306
**kwargs,
33083307
):
33093308
passed_params = locals().copy()

tests/llm_translation/test_openai.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,3 +664,62 @@ async def test_openai_gpt5_reasoning():
664664
)
665665
print("response: ", response)
666666
assert response.choices[0].message.content is not None
667+
668+
669+
@pytest.mark.asyncio
async def test_openai_safety_identifier_parameter():
    """Verify that `safety_identifier` is forwarded to the raw OpenAI API call (async path)."""
    from openai import AsyncOpenAI

    litellm.set_verbose = True
    client = AsyncOpenAI(api_key="fake-api-key")

    conversation = [{"role": "user", "content": "Hello, how are you?"}]

    # Intercept the raw-response `create` call so we can inspect exactly
    # what litellm hands to the OpenAI SDK.
    with patch.object(
        client.chat.completions.with_raw_response, "create"
    ) as create_mock:
        try:
            await litellm.acompletion(
                model="openai/gpt-4o",
                messages=conversation,
                safety_identifier="user_code_123456",
                client=client,
            )
        except Exception as e:
            # The mocked client returns a bare MagicMock, so downstream
            # response handling may raise — that's fine, we only care
            # about the outbound request kwargs.
            print(f"Error: {e}")

    create_mock.assert_called_once()
    sent_kwargs = create_mock.call_args.kwargs

    # The raw request must carry the parameter, with the exact value we passed.
    assert "safety_identifier" in sent_kwargs
    assert sent_kwargs["safety_identifier"] == "user_code_123456"
697+
698+
699+
def test_openai_safety_identifier_parameter_sync():
    """Verify that `safety_identifier` is forwarded to the raw OpenAI API call (sync path)."""
    from openai import OpenAI

    litellm.set_verbose = True
    client = OpenAI(api_key="fake-api-key")

    conversation = [{"role": "user", "content": "Hello, how are you?"}]

    # Intercept the raw-response `create` call so we can inspect exactly
    # what litellm hands to the OpenAI SDK.
    with patch.object(
        client.chat.completions.with_raw_response, "create"
    ) as create_mock:
        try:
            litellm.completion(
                model="openai/gpt-4o",
                messages=conversation,
                safety_identifier="user_code_123456",
                client=client,
            )
        except Exception as e:
            # The mocked client returns a bare MagicMock, so downstream
            # response handling may raise — that's fine, we only care
            # about the outbound request kwargs.
            print(f"Error: {e}")

    create_mock.assert_called_once()
    sent_kwargs = create_mock.call_args.kwargs

    # The raw request must carry the parameter, with the exact value we passed.
    assert "safety_identifier" in sent_kwargs
    assert sent_kwargs["safety_identifier"] == "user_code_123456"

0 commit comments

Comments
 (0)