54 changes: 43 additions & 11 deletions src/google/adk/models/lite_llm.py
@@ -81,6 +81,10 @@ class TextChunk(BaseModel):
text: str


class ThoughtChunk(BaseModel):
text: str


class UsageMetadataChunk(BaseModel):
prompt_tokens: int
completion_tokens: int
@@ -407,19 +411,19 @@ def _model_response_to_chunk(
response: ModelResponse,
) -> Generator[
Tuple[
Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
Optional[Union[TextChunk, ThoughtChunk, FunctionChunk, UsageMetadataChunk]],
Optional[str],
],
None,
None,
]:
"""Converts a litellm message to text, function or usage metadata chunk.
"""Converts a litellm message to text, thought, function or usage metadata chunk.

Args:
response: The response from the model.

Yields:
A tuple of text or function or usage metadata chunk and finish reason.
A tuple of text, thought, function or usage metadata chunk and finish reason.
"""

message = None
@@ -430,6 +434,9 @@ def _model_response_to_chunk(
if message is None and response["choices"][0].get("delta", None):
message = response["choices"][0]["delta"]

if message.get("reasoning_content", None):
yield ThoughtChunk(text=message.get("reasoning_content")), finish_reason

if message.get("content", None):
yield TextChunk(text=message.get("content")), finish_reason

@@ -452,7 +459,9 @@
), finish_reason

if finish_reason and not (
message.get("content", None) or message.get("tool_calls", None)
message.get("reasoning_content", None)
or message.get("content", None)
or message.get("tool_calls", None)
):
yield None, finish_reason

@@ -513,6 +522,11 @@ def _message_to_generate_content_response(
"""

parts = []

if reasoning_content := message.get("reasoning_content"):
thought_part = types.Part(text=reasoning_content, thought=True)
parts.append(thought_part)

if message.get("content", None):
parts.append(types.Part.from_text(text=message.get("content")))

@@ -830,6 +844,7 @@ async def generate_content_async(

if stream:
text = ""
reasoning_text = ""
# Track function calls by index
function_calls = {} # index -> {name, args, id}
completion_args["stream"] = True
@@ -860,6 +875,15 @@ async def generate_content_async(
function_calls[index]["id"] = (
chunk.id or function_calls[index]["id"] or str(index)
)
elif isinstance(chunk, ThoughtChunk):
reasoning_text += chunk.text
yield _message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
reasoning_content=chunk.text,
),
is_partial=True,
)
elif isinstance(chunk, TextChunk):
text += chunk.text
yield _message_to_generate_content_response(
@@ -893,22 +917,30 @@ async def generate_content_async(
),
)
)
message_kwargs = {
"role": "assistant",
"content": text,
"tool_calls": tool_calls,
}
if reasoning_text:
message_kwargs["reasoning_content"] = reasoning_text
aggregated_llm_response_with_tool_call = (
_message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content=text,
tool_calls=tool_calls,
)
ChatCompletionAssistantMessage(**message_kwargs)
)
)
text = ""
reasoning_text = ""
function_calls.clear()
elif finish_reason == "stop" and text:
elif finish_reason == "stop" and (text or reasoning_text):
message_kwargs = {"role": "assistant", "content": text}
if reasoning_text:
message_kwargs["reasoning_content"] = reasoning_text
aggregated_llm_response = _message_to_generate_content_response(
ChatCompletionAssistantMessage(role="assistant", content=text)
ChatCompletionAssistantMessage(**message_kwargs)
)
text = ""
reasoning_text = ""

# waiting until streaming ends to yield the llm_response as litellm tends
# to send chunk that contains usage_metadata after the chunk with
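For orientation, below is a minimal sketch (not part of the diff) of the conversion this change introduces: a message carrying reasoning_content yields a thought-flagged Part ahead of the regular text Part, matching _message_to_generate_content_response above. The build_parts helper is hypothetical and exists only for illustration.

from typing import List, Optional

from google.genai import types


def build_parts(
    reasoning_content: Optional[str], content: Optional[str]
) -> List[types.Part]:
  # Hypothetical helper mirroring the conversion added above: reasoning text
  # becomes a Part flagged thought=True and is placed before the visible text.
  parts: List[types.Part] = []
  if reasoning_content:
    parts.append(types.Part(text=reasoning_content, thought=True))
  if content:
    parts.append(types.Part.from_text(text=content))
  return parts


parts = build_parts("Let me think step by step...", "The answer is 42.")
assert parts[0].thought is True and parts[1].thought is not True
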
155 changes: 155 additions & 0 deletions tests/unittests/models/test_litellm.py
@@ -28,6 +28,7 @@
from google.adk.models.lite_llm import LiteLlm
from google.adk.models.lite_llm import LiteLLMClient
from google.adk.models.lite_llm import TextChunk
from google.adk.models.lite_llm import ThoughtChunk
from google.adk.models.lite_llm import UsageMetadataChunk
from google.adk.models.llm_request import LlmRequest
from google.genai import types
@@ -461,6 +462,61 @@ def mock_response():
]


STREAMING_MODEL_RESPONSE_WITH_REASONING = [
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
reasoning_content="Let me think",
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
reasoning_content=" step by step...",
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
content="The answer is ",
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
content="42.",
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason="stop",
)
]
),
]


@pytest.fixture
def mock_acompletion(mock_response):
return AsyncMock(return_value=mock_response)
@@ -1138,6 +1194,28 @@ def test_message_to_generate_content_response_tool_call():
assert response.content.parts[0].function_call.id == "test_tool_call_id"


def test_message_to_generate_content_response_with_reasoning_content():
message = ChatCompletionAssistantMessage(
role="assistant",
reasoning_content="Thinking step-by-step...",
content="Hello!",
)

response = _message_to_generate_content_response(message)
assert response.content.role == "model"
assert len(response.content.parts) == 2

# Check that thought part is created first
thought_part = response.content.parts[0]
assert thought_part.text == message["reasoning_content"]
assert thought_part.thought is True

# Check that regular content follows
text_part = response.content.parts[1]
assert text_part.text == message["content"]
assert text_part.thought is not True


def test_get_content_text():
parts = [types.Part.from_text(text="Test text")]
content = _get_content(parts)
@@ -1849,3 +1927,80 @@ def test_non_gemini_litellm_no_warning():
# Test with non-Gemini model
LiteLlm(model="openai/gpt-4o")
assert len(w) == 0


@pytest.mark.asyncio
async def test_generate_content_async_stream_with_reasoning_content(
mock_completion, lite_llm_instance
):
"""Test streaming with reasoning content (thought) chunks.

This test verifies that:
1. Reasoning content chunks are handled correctly in streaming mode
2. ThoughtChunk objects are yielded for reasoning_content
3. Thought parts are yielded incrementally with is_partial=True
4. The final response contains both reasoning and regular content
"""
mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE_WITH_REASONING)

llm_request = LlmRequest(
contents=[
types.Content(
role="user",
parts=[types.Part.from_text(text="What is the meaning of life?")],
)
],
)

responses = []
async for response in lite_llm_instance.generate_content_async(
llm_request, stream=True
):
responses.append(response)

# Should have 4 partial responses (2 thoughts + 2 text) + 1 final aggregated
assert len(responses) == 5

# First response: thought chunk
assert responses[0].content.role == "model"
assert len(responses[0].content.parts) == 1
assert responses[0].content.parts[0].text == "Let me think"
assert responses[0].content.parts[0].thought is True
assert responses[0].partial is True

# Second response: thought chunk
assert responses[1].content.role == "model"
assert len(responses[1].content.parts) == 1
assert responses[1].content.parts[0].text == " step by step..."
assert responses[1].content.parts[0].thought is True
assert responses[1].partial is True

# Third response: text chunk
assert responses[2].content.role == "model"
assert len(responses[2].content.parts) == 1
assert responses[2].content.parts[0].text == "The answer is "
assert responses[2].content.parts[0].thought is not True
assert responses[2].partial is True

# Fourth response: text chunk
assert responses[3].content.role == "model"
assert len(responses[3].content.parts) == 1
assert responses[3].content.parts[0].text == "42."
assert responses[3].content.parts[0].thought is not True
assert responses[3].partial is True

# Final aggregated response should contain both reasoning and content
final_response = responses[4]
assert final_response.content.role == "model"
assert len(final_response.content.parts) == 2

# First part should be the accumulated thought
assert final_response.content.parts[0].text == "Let me think step by step..."
assert final_response.content.parts[0].thought is True

# Second part should be the accumulated text content
assert final_response.content.parts[1].text == "The answer is 42."
assert final_response.content.parts[1].thought is not True

# Final response should not be partial
assert final_response.partial is False
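
A minimal consumer sketch, assuming a configured LiteLlm instance as in the tests above; the model name is a placeholder, and the reasoning/answer split simply checks the thought flag on each part of the final aggregated response.

import asyncio

from google.adk.models.lite_llm import LiteLlm
from google.adk.models.llm_request import LlmRequest
from google.genai import types


async def main():
  llm = LiteLlm(model="openai/gpt-4o")  # placeholder model
  request = LlmRequest(
      contents=[
          types.Content(
              role="user",
              parts=[types.Part.from_text(text="What is the meaning of life?")],
          )
      ],
  )
  thoughts, answer = [], []
  async for response in llm.generate_content_async(request, stream=True):
    if response.partial or not response.content:
      continue  # skip incremental chunks; the final response aggregates them
    for part in response.content.parts:
      # Thought parts carry reasoning_content; regular parts carry the reply.
      (thoughts if part.thought else answer).append(part.text or "")
  print("reasoning:", "".join(thoughts))
  print("answer:", "".join(answer))


asyncio.run(main())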