62 changes: 39 additions & 23 deletions src/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -128,7 +128,9 @@ def _add_additional_properties_recursive(self, schema):
         return schema

     async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        input_dict = {}
+        from typing import Any
+
+        input_dict: dict[str, Any] = {}

         input_dict["messages"] = [
             await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
@@ -139,30 +141,27 @@ async def _get_params(self, request: ChatCompletionRequest) -> dict:
                     f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
                 )

-            fmt = fmt.json_schema
-            name = fmt["title"]
-            del fmt["title"]
-            fmt["additionalProperties"] = False
+            # Convert to dict for manipulation
+            fmt_dict = dict(fmt.json_schema)
+            name = fmt_dict["title"]
+            del fmt_dict["title"]
+            fmt_dict["additionalProperties"] = False

             # Apply additionalProperties: False recursively to all objects
-            fmt = self._add_additional_properties_recursive(fmt)
+            fmt_dict = self._add_additional_properties_recursive(fmt_dict)

             input_dict["response_format"] = {
                 "type": "json_schema",
                 "json_schema": {
                     "name": name,
-                    "schema": fmt,
+                    "schema": fmt_dict,
                     "strict": self.json_schema_strict,
                 },
             }
         if request.tools:
             input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
-        if request.tool_config.tool_choice:
-            input_dict["tool_choice"] = (
-                request.tool_config.tool_choice.value
-                if isinstance(request.tool_config.tool_choice, ToolChoice)
-                else request.tool_config.tool_choice
-            )
+        if request.tool_config and (tool_choice := request.tool_config.tool_choice):
+            input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice

         return {
             "model": request.model,
@@ -176,10 +175,10 @@ async def _get_params(self, request: ChatCompletionRequest) -> dict:
     def get_api_key(self) -> str:
         provider_data = self.get_request_provider_data()
         key_field = self.provider_data_api_key_field
-        if provider_data and getattr(provider_data, key_field, None):
-            api_key = getattr(provider_data, key_field)
-        else:
-            api_key = self.api_key_from_config
+        if provider_data and key_field and (api_key := getattr(provider_data, key_field, None)):
+            return str(api_key)  # type: ignore[no-any-return] # getattr returns Any, can't narrow without runtime type inspection
+
+        api_key = self.api_key_from_config
         if not api_key:
             raise ValueError(
                 "API key is not set. Please provide a valid API key in the "
@@ -192,15 +191,20 @@ async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        # Fallback to params.model ensures provider_resource_id is always str
+        provider_resource_id: str = (model_obj.provider_resource_id if model_obj else None) or params.model

         # Convert input to list if it's a string
         input_list = [params.input] if isinstance(params.input, str) else params.input

         # Call litellm embedding function
         # litellm.drop_params = True
         response = litellm.embedding(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             input=input_list,
             api_key=self.get_api_key(),
             api_base=self.api_base,
@@ -217,18 +221,23 @@ async def openai_embeddings(

         return OpenAIEmbeddingsResponse(
             data=data,
-            model=model_obj.provider_resource_id,
+            model=provider_resource_id,
             usage=usage,
         )

     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        # Fallback to params.model ensures provider_resource_id is always str
+        provider_resource_id: str = (model_obj.provider_resource_id if model_obj else None) or params.model

         request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
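The `(x if model_obj else None) or params.model` shape now repeats across all three endpoints. It covers two cases at once: `model_obj` being None, and `model_obj` existing with a None `provider_resource_id`, since `None or params.model` falls through either way. A tiny illustration with made-up values:

```python
class Model:
    def __init__(self, provider_resource_id):
        self.provider_resource_id = provider_resource_id

def resolve(model_obj, requested: str) -> str:
    # Same expression as in the diff: a missing model_obj and a present
    # model_obj with provider_resource_id=None both fall back to the
    # caller-supplied model string, so the result is always a str.
    return (model_obj.provider_resource_id if model_obj else None) or requested

assert resolve(None, "gpt-4o-mini") == "gpt-4o-mini"
assert resolve(Model(None), "gpt-4o-mini") == "gpt-4o-mini"
assert resolve(Model("openai/gpt-4o-mini"), "gpt-4o-mini") == "openai/gpt-4o-mini"
```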
@@ -249,7 +258,8 @@ async def openai_completion(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        return await litellm.atext_completion(**request_params)
+        # LiteLLM returns compatible type but mypy can't verify external library
+        return await litellm.atext_completion(**request_params)  # type: ignore[no-any-return] # external lib lacks type stubs

     async def openai_chat_completion(
         self,
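On the `# type: ignore[no-any-return]` returns: an alternative that keeps mypy strict without a blanket ignore is `typing.cast`, which records the expected type at the call site and is a no-op at runtime. A generic sketch of that option (purely illustrative; the ignore comment chosen in this diff is equally valid, and `cast` silently trusts the asserted type just as the ignore does):

```python
import asyncio
from typing import Any, cast

async def untyped_call() -> Any:
    # Stand-in for a call into an untyped external library such as litellm.
    return {"object": "text_completion"}

async def completion() -> dict:
    # cast() does nothing at runtime; it only tells mypy what the untyped
    # call is expected to return, so strict mode passes without an ignore.
    return cast(dict, await untyped_call())

assert asyncio.run(completion())["object"] == "text_completion"
```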
@@ -265,10 +275,15 @@ async def openai_chat_completion(
             elif "include_usage" not in stream_options:
                 stream_options = {**stream_options, "include_usage": True}

+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        # Fallback to params.model ensures provider_resource_id is always str
+        provider_resource_id: str = (model_obj.provider_resource_id if model_obj else None) or params.model

         request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             messages=params.messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
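Context for the `stream_options` lines at the top of the previous hunk: the merge is written so an explicit `include_usage` from the caller is never overridden. A condensed hypothetical version of that normalization, assuming the off-screen `if` branch creates the dict when `stream_options` is None:

```python
def normalize_stream_options(stream: bool, stream_options: dict | None) -> dict | None:
    # Hypothetical condensation of the logic in openai_chat_completion:
    # when streaming, default include_usage to True, but keep any value
    # the caller set explicitly.
    if not stream:
        return stream_options
    if stream_options is None:
        return {"include_usage": True}
    if "include_usage" not in stream_options:
        return {**stream_options, "include_usage": True}
    return stream_options

assert normalize_stream_options(True, None) == {"include_usage": True}
assert normalize_stream_options(True, {"include_usage": False}) == {"include_usage": False}
assert normalize_stream_options(False, None) is None
```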
@@ -294,7 +309,8 @@ async def openai_chat_completion(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        return await litellm.acompletion(**request_params)
+        # LiteLLM returns compatible type but mypy can't verify external library
+        return await litellm.acompletion(**request_params)  # type: ignore[no-any-return] # external lib lacks type stubs

     async def check_model_availability(self, model: str) -> bool:
         """