Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ async def generate_content_async(
tools = None

completion_args = {
"model": self.model,
"model": llm_request.model or self.model,
"messages": messages,
"tools": tools,
"response_format": response_format,
Expand Down
67 changes: 67 additions & 0 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,73 @@ async def test_generate_content_async(mock_acompletion, lite_llm_instance):
)


@pytest.mark.asyncio
async def test_generate_content_async_with_model_override(
    mock_acompletion, lite_llm_instance
):
  """Checks that a model set on the LlmRequest takes precedence.

  When llm_request.model is provided, the litellm completion call must be
  made with that model rather than the model configured on the instance.
  """
  # Build the tool declaration separately to keep the request literal flat.
  function_decl = types.FunctionDeclaration(
      name="test_function",
      description="Test function description",
      parameters=types.Schema(
          type=types.Type.OBJECT,
          properties={"test_arg": types.Schema(type=types.Type.STRING)},
      ),
  )
  request = LlmRequest(
      model="overridden_model",
      contents=[
          types.Content(
              role="user", parts=[types.Part.from_text(text="Test prompt")]
          )
      ],
      config=types.GenerateContentConfig(
          tools=[types.Tool(function_declarations=[function_decl])],
      ),
  )

  # Drain the async generator, checking every yielded response.
  responses = [
      resp async for resp in lite_llm_instance.generate_content_async(request)
  ]
  for resp in responses:
    assert resp.content.role == "model"
    assert resp.content.parts[0].text == "Test response"

  mock_acompletion.assert_called_once()
  _, kwargs = mock_acompletion.call_args
  # The request-level model must win over the instance default.
  assert kwargs["model"] == "overridden_model"
  assert kwargs["messages"][0]["role"] == "user"
  assert kwargs["messages"][0]["content"] == "Test prompt"


@pytest.mark.asyncio
async def test_generate_content_async_without_model_override(
    mock_acompletion, lite_llm_instance
):
  """Checks the fallback when the LlmRequest carries no model.

  With llm_request.model set to None, the litellm completion call must use
  the model configured on the LiteLlm instance ("test_model" in the fixture).
  """
  request = LlmRequest(
      model=None,
      contents=[
          types.Content(
              role="user", parts=[types.Part.from_text(text="Test prompt")]
          )
      ],
  )

  async for resp in lite_llm_instance.generate_content_async(request):
    assert resp.content.role == "model"

  mock_acompletion.assert_called_once()
  _, kwargs = mock_acompletion.call_args
  # Falls back to the model configured on the instance.
  assert kwargs["model"] == "test_model"


litellm_append_user_content_test_cases = [
pytest.param(
LlmRequest(
Expand Down