diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index d8b4d7ce81..080bf30b50 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -818,7 +818,7 @@ async def generate_content_async( tools = None completion_args = { - "model": self.model, + "model": llm_request.model or self.model, "messages": messages, "tools": tools, "response_format": response_format, diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 84fd7f26d0..a47506b71d 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -548,6 +548,53 @@ async def test_generate_content_async(mock_acompletion, lite_llm_instance): ) +@pytest.mark.asyncio +async def test_generate_content_async_with_model_override( + mock_acompletion, lite_llm_instance +): + llm_request = LlmRequest( + model="overridden_model", + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + ) + + async for response in lite_llm_instance.generate_content_async(llm_request): + assert response.content.role == "model" + assert response.content.parts[0].text == "Test response" + + mock_acompletion.assert_called_once() + + _, kwargs = mock_acompletion.call_args + assert kwargs["model"] == "overridden_model" + assert kwargs["messages"][0]["role"] == "user" + assert kwargs["messages"][0]["content"] == "Test prompt" + + +@pytest.mark.asyncio +async def test_generate_content_async_without_model_override( + mock_acompletion, lite_llm_instance +): + llm_request = LlmRequest( + model=None, + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + ) + + async for response in lite_llm_instance.generate_content_async(llm_request): + assert response.content.role == "model" + + mock_acompletion.assert_called_once() + + _, kwargs = mock_acompletion.call_args + assert kwargs["model"] == "test_model" + + litellm_append_user_content_test_cases = [ pytest.param( LlmRequest(