diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index d8b4d7ce81..d7510c26ad 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -81,6 +81,10 @@ class TextChunk(BaseModel):
   text: str
 
 
+class ThoughtChunk(BaseModel):
+  text: str
+
+
 class UsageMetadataChunk(BaseModel):
   prompt_tokens: int
   completion_tokens: int
@@ -407,19 +411,19 @@ def _model_response_to_chunk(
     response: ModelResponse,
 ) -> Generator[
     Tuple[
-        Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
+        Optional[Union[TextChunk, ThoughtChunk, FunctionChunk, UsageMetadataChunk]],
         Optional[str],
     ],
     None,
     None,
 ]:
-  """Converts a litellm message to text, function or usage metadata chunk.
+  """Converts a litellm message to text, thought, function or usage metadata chunk.
 
   Args:
     response: The response from the model.
 
   Yields:
-    A tuple of text or function or usage metadata chunk and finish reason.
+    A tuple of text, thought, function or usage metadata chunk and finish reason.
   """
 
   message = None
@@ -430,6 +434,9 @@ def _model_response_to_chunk(
     if message is None and response["choices"][0].get("delta", None):
       message = response["choices"][0]["delta"]
 
+    if message.get("reasoning_content", None):
+      yield ThoughtChunk(text=message.get("reasoning_content")), finish_reason
+
     if message.get("content", None):
       yield TextChunk(text=message.get("content")), finish_reason
 
@@ -452,7 +459,9 @@ def _model_response_to_chunk(
           ), finish_reason
 
     if finish_reason and not (
-        message.get("content", None) or message.get("tool_calls", None)
+        message.get("reasoning_content", None)
+        or message.get("content", None)
+        or message.get("tool_calls", None)
     ):
       yield None, finish_reason
 
@@ -513,6 +522,11 @@ def _message_to_generate_content_response(
   """
 
   parts = []
+
+  if reasoning_content := message.get("reasoning_content"):
+    thought_part = types.Part(text=reasoning_content, thought=True)
+    parts.append(thought_part)
+
   if message.get("content", None):
     parts.append(types.Part.from_text(text=message.get("content")))
 
@@ -830,6 +844,7 @@ async def generate_content_async(
 
     if stream:
       text = ""
+      reasoning_text = ""
       # Track function calls by index
       function_calls = {}  # index -> {name, args, id}
       completion_args["stream"] = True
@@ -860,6 +875,15 @@ async def generate_content_async(
             function_calls[index]["id"] = (
                 chunk.id or function_calls[index]["id"] or str(index)
             )
+          elif isinstance(chunk, ThoughtChunk):
+            reasoning_text += chunk.text
+            yield _message_to_generate_content_response(
+                ChatCompletionAssistantMessage(
+                    role="assistant",
+                    reasoning_content=chunk.text,
+                ),
+                is_partial=True,
+            )
           elif isinstance(chunk, TextChunk):
             text += chunk.text
             yield _message_to_generate_content_response(
@@ -893,22 +917,30 @@ async def generate_content_async(
                         ),
                     )
                 )
+            message_kwargs = {
+                "role": "assistant",
+                "content": text,
+                "tool_calls": tool_calls,
+            }
+            if reasoning_text:
+              message_kwargs["reasoning_content"] = reasoning_text
             aggregated_llm_response_with_tool_call = (
                 _message_to_generate_content_response(
-                    ChatCompletionAssistantMessage(
-                        role="assistant",
-                        content=text,
-                        tool_calls=tool_calls,
-                    )
+                    ChatCompletionAssistantMessage(**message_kwargs)
                 )
             )
             text = ""
+            reasoning_text = ""
             function_calls.clear()
-          elif finish_reason == "stop" and text:
+          elif finish_reason == "stop" and (text or reasoning_text):
+            message_kwargs = {"role": "assistant", "content": text}
+            if reasoning_text:
+              message_kwargs["reasoning_content"] = reasoning_text
             aggregated_llm_response = _message_to_generate_content_response(
-                ChatCompletionAssistantMessage(role="assistant", content=text)
+                ChatCompletionAssistantMessage(**message_kwargs)
             )
             text = ""
+            reasoning_text = ""
 
       # waiting until streaming ends to yield the llm_response as litellm tends
       # to send chunk that contains usage_metadata after the chunk with
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index 84fd7f26d0..76e2ed2ae5 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -28,6 +28,7 @@
 from google.adk.models.lite_llm import LiteLlm
 from google.adk.models.lite_llm import LiteLLMClient
 from google.adk.models.lite_llm import TextChunk
+from google.adk.models.lite_llm import ThoughtChunk
 from google.adk.models.lite_llm import UsageMetadataChunk
 from google.adk.models.llm_request import LlmRequest
 from google.genai import types
@@ -461,6 +462,61 @@ def mock_response():
 ]
 
 
+STREAMING_MODEL_RESPONSE_WITH_REASONING = [
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    reasoning_content="Let me think",
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    reasoning_content=" step by step...",
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    content="The answer is ",
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    content="42.",
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason="stop",
+            )
+        ]
+    ),
+]
+
+
 @pytest.fixture
 def mock_acompletion(mock_response):
   return AsyncMock(return_value=mock_response)
@@ -1138,6 +1194,28 @@ def test_message_to_generate_content_response_tool_call():
   assert response.content.parts[0].function_call.id == "test_tool_call_id"
 
 
+def test_message_to_generate_content_response_with_reasoning_content():
+  message = ChatCompletionAssistantMessage(
+      role="assistant",
+      reasoning_content="Thinking step-by-step...",
+      content="Hello!",
+  )
+
+  response = _message_to_generate_content_response(message)
+  assert response.content.role == "model"
+  assert len(response.content.parts) == 2
+
+  # Check that thought part is created first
+  thought_part = response.content.parts[0]
+  assert thought_part.text == message["reasoning_content"]
+  assert thought_part.thought is True
+
+  # Check that regular content follows
+  text_part = response.content.parts[1]
+  assert text_part.text == message["content"]
+  assert text_part.thought is not True
+
+
 def test_get_content_text():
   parts = [types.Part.from_text(text="Test text")]
   content = _get_content(parts)
@@ -1849,3 +1927,80 @@ def test_non_gemini_litellm_no_warning():
     # Test with non-Gemini model
     LiteLlm(model="openai/gpt-4o")
     assert len(w) == 0
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_with_reasoning_content(
+    mock_completion, lite_llm_instance
+):
+  """Test streaming with reasoning content (thought) chunks.
+
+  This test verifies that:
+  1. Reasoning content chunks are handled correctly in streaming mode
+  2. ThoughtChunk objects are yielded for reasoning_content
+  3. Thought parts are yielded incrementally with is_partial=True
+  4. The final response contains both reasoning and regular content
+  """
+  mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE_WITH_REASONING)
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[types.Part.from_text(text="What is the meaning of life?")],
+          )
+      ],
+  )
+
+  responses = []
+  async for response in lite_llm_instance.generate_content_async(
+      llm_request, stream=True
+  ):
+    responses.append(response)
+
+  # Should have 4 partial responses (2 thoughts + 2 text) + 1 final aggregated
+  assert len(responses) == 5
+
+  # First response: thought chunk
+  assert responses[0].content.role == "model"
+  assert len(responses[0].content.parts) == 1
+  assert responses[0].content.parts[0].text == "Let me think"
+  assert responses[0].content.parts[0].thought is True
+  assert responses[0].partial is True
+
+  # Second response: thought chunk
+  assert responses[1].content.role == "model"
+  assert len(responses[1].content.parts) == 1
+  assert responses[1].content.parts[0].text == " step by step..."
+  assert responses[1].content.parts[0].thought is True
+  assert responses[1].partial is True
+
+  # Third response: text chunk
+  assert responses[2].content.role == "model"
+  assert len(responses[2].content.parts) == 1
+  assert responses[2].content.parts[0].text == "The answer is "
+  assert responses[2].content.parts[0].thought is not True
+  assert responses[2].partial is True
+
+  # Fourth response: text chunk
+  assert responses[3].content.role == "model"
+  assert len(responses[3].content.parts) == 1
+  assert responses[3].content.parts[0].text == "42."
+  assert responses[3].content.parts[0].thought is not True
+  assert responses[3].partial is True
+
+  # Final aggregated response should contain both reasoning and content
+  final_response = responses[4]
+  assert final_response.content.role == "model"
+  assert len(final_response.content.parts) == 2
+
+  # First part should be the accumulated thought
+  assert final_response.content.parts[0].text == "Let me think step by step..."
+  assert final_response.content.parts[0].thought is True
+
+  # Second part should be the accumulated text content
+  assert final_response.content.parts[1].text == "The answer is 42."
+  assert final_response.content.parts[1].thought is not True
+
+  # Final response should not be partial
+  assert final_response.partial is False