210 changes: 198 additions & 12 deletions src/google/adk/models/anthropic_llm.py
@@ -30,6 +30,7 @@
from typing import Union

from anthropic import AnthropicVertex
from anthropic import AsyncAnthropicVertex
from anthropic import NOT_GIVEN
from anthropic import types as anthropic_types
from google.genai import types
@@ -166,7 +167,83 @@ def content_block_to_part(
    )
    part.function_call.id = content_block.id
    return part
-  raise NotImplementedError("Not supported yet.")

  # Handle thinking blocks from Anthropic's extended thinking feature.
  # Thinking blocks carry a `thinking` attribute with the reasoning text.
  if hasattr(content_block, "thinking"):
    thinking_text = content_block.thinking
    logger.info(f"Received thinking block ({len(thinking_text)} chars)")
    # Return as a Part with thought=True (standard GenAI format).
    return types.Part(text=thinking_text, thought=True)

  # Alternative check: some SDK versions expose only a `type` attribute.
  if (
      hasattr(content_block, "type")
      and getattr(content_block, "type", None) == "thinking"
  ):
    thinking_text = str(content_block)
    logger.info(
        f"Received thinking block via type check ({len(thinking_text)} chars)"
    )
    # Return as a Part with thought=True (standard GenAI format).
    return types.Part(text=thinking_text, thought=True)

  raise NotImplementedError(
      f"Not supported yet: {type(content_block).__name__}"
  )
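
For reference, a minimal sketch of the mapping this enables. It assumes the
anthropic SDK's ThinkingBlock shape; the sample text and empty signature are
illustrative:

    # Hypothetical thinking block, as produced by extended thinking.
    block = anthropic_types.ThinkingBlock(
        type="thinking", thinking="Check the units first...", signature=""
    )
    part = content_block_to_part(block)
    # The reasoning text lands in a Part flagged as a thought.
    assert part.thought and part.text == block.thinking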


def streaming_event_to_llm_response(
    event: anthropic_types.MessageStreamEvent,
) -> Optional[LlmResponse]:
  """Converts an Anthropic streaming event to the ADK LlmResponse format.

  Args:
    event: Anthropic streaming event.

  Returns:
    An LlmResponse, or None if the event should be skipped.
  """
  # Handle content block deltas.
  if event.type == "content_block_delta":
    delta = event.delta

    # Text delta.
    if hasattr(delta, "type") and delta.type == "text_delta":
      return LlmResponse(
          content=types.Content(
              role="model",
              parts=[types.Part.from_text(text=delta.text)],
          ),
          partial=True,
      )

    # Thinking delta.
    if hasattr(delta, "type") and delta.type == "thinking_delta":
      return LlmResponse(
          content=types.Content(
              role="model",
              parts=[types.Part(text=delta.thinking, thought=True)],
          ),
          partial=True,
      )

  # Handle message deltas (usage updates).
  elif event.type == "message_delta":
    if hasattr(event, "usage"):
      return LlmResponse(
          usage_metadata=types.GenerateContentResponseUsageMetadata(
              prompt_token_count=getattr(event.usage, "input_tokens", 0),
              candidates_token_count=getattr(event.usage, "output_tokens", 0),
              total_token_count=(
                  getattr(event.usage, "input_tokens", 0)
                  + getattr(event.usage, "output_tokens", 0)
              ),
          ),
      )

  # Ignore start/stop events.
  return None
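
A minimal consumption sketch, assuming an already-open Anthropic message
stream (this mirrors how generate_content_async below drives the function):

    # Inside an async generator, with `anthropic_stream` already open.
    async for event in anthropic_stream:
      if (llm_response := streaming_event_to_llm_response(event)) is not None:
        yield llm_response  # forward the partial response to the caller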


def message_to_generate_content_response(
Expand Down Expand Up @@ -283,19 +360,128 @@ async def generate_content_async(
        if llm_request.tools_dict
        else NOT_GIVEN
    )
-    # TODO(b/421255973): Enable streaming for anthropic models.
-    message = self._anthropic_client.messages.create(
-        model=llm_request.model,
-        system=llm_request.config.system_instruction,
-        messages=messages,
-        tools=tools,
-        tool_choice=tool_choice,
-        max_tokens=self.max_tokens,

    # Extract and convert the thinking config from ADK to Anthropic format.
    thinking = NOT_GIVEN
    if llm_request.config and llm_request.config.thinking_config:
      budget = llm_request.config.thinking_config.thinking_budget
      if budget:
        if budget == -1:
          # Automatic budget: fall back to Anthropic's documented minimum of
          # 1024 tokens.
          thinking = {"type": "enabled", "budget_tokens": 1024}
          logger.info(
              "Extended thinking enabled (automatic budget: 1024 tokens)"
          )
        elif budget > 0:
          # Specific budget: enforce the 1024-token minimum.
          actual_budget = max(budget, 1024)
          thinking = {"type": "enabled", "budget_tokens": actual_budget}
          logger.info(
              f"Extended thinking enabled (budget: {actual_budget} tokens)"
          )

    # Determine whether streaming should be used:
    # - `stream` comes from the runtime context (streaming_mode == SSE);
    # - extended thinking requires streaming (Anthropic-specific);
    # - large max_tokens may exceed the 10-minute timeout, which the
    #   Anthropic SDK requires streaming to avoid.
    use_streaming = (
        stream
        or thinking != NOT_GIVEN
        or self.max_tokens >= 8192
    )
-    yield message_to_generate_content_response(message)

    if use_streaming:
      logger.info(
          f"Using streaming mode (stream={stream}, "
          f"has_thinking={thinking != NOT_GIVEN}, "
          f"large_max_tokens={self.max_tokens >= 8192})"
      )

      # Accumulators for text and thinking content.
      accumulated_text = ""
      accumulated_thinking = ""

      async with self._anthropic_client.messages.stream(
          model=llm_request.model,
          system=llm_request.config.system_instruction,
          messages=messages,
          tools=tools,
          tool_choice=tool_choice,
          max_tokens=self.max_tokens,
          thinking=thinking,
      ) as anthropic_stream:
        # Convert each Anthropic event to an LlmResponse.
        async for event in anthropic_stream:
          if llm_response := streaming_event_to_llm_response(event):
            # Track accumulated content.
            is_thought = False
            if llm_response.content and llm_response.content.parts:
              for part in llm_response.content.parts:
                if part.text:
                  if part.thought:
                    accumulated_thinking += part.text
                    is_thought = True
                  else:
                    accumulated_text += part.text

            # When text starts arriving after thinking, yield the
            # accumulated thinking first so it precedes the visible text.
            if accumulated_thinking and accumulated_text and not is_thought:
              yield LlmResponse(
                  content=types.Content(
                      role="model",
                      parts=[
                          types.Part(text=accumulated_thinking, thought=True)
                      ],
                  ),
                  partial=True,
              )
              accumulated_thinking = ""  # Reset after yielding.

            # Yield the partial response (individual thought deltas are
            # skipped; they are aggregated and yielded above).
            if not is_thought:
              yield llm_response

        # Get the final message to extract usage metadata.
        final_message = await anthropic_stream.get_final_message()

      # Build the final aggregated response with the complete content.
      parts = []
      if accumulated_thinking:
        parts.append(types.Part(text=accumulated_thinking, thought=True))
      if accumulated_text:
        parts.append(types.Part.from_text(text=accumulated_text))

      # Only yield a final aggregated response if there is content.
      if parts:
        yield LlmResponse(
            content=types.Content(role="model", parts=parts),
            usage_metadata=types.GenerateContentResponseUsageMetadata(
                prompt_token_count=final_message.usage.input_tokens,
                candidates_token_count=final_message.usage.output_tokens,
                total_token_count=(
                    final_message.usage.input_tokens
                    + final_message.usage.output_tokens
                ),
            ),
        )

    else:
      # Non-streaming mode (simple requests without thinking).
      logger.info("Using non-streaming mode")
      message = await self._anthropic_client.messages.create(
          model=llm_request.model,
          system=llm_request.config.system_instruction,
          messages=messages,
          tools=tools,
          tool_choice=tool_choice,
          max_tokens=self.max_tokens,
      )
      yield message_to_generate_content_response(message)
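
For context, a hedged sketch of the ADK-side request that exercises the
extended-thinking path above (the model name and budget values are
illustrative; ThinkingConfig comes from google.genai.types):

    claude_llm = Claude(model="claude-sonnet-4@20250514", max_tokens=8192)
    # A budget of -1 requests the automatic budget handled above.
    llm_request.config.thinking_config = types.ThinkingConfig(
        thinking_budget=2048
    )
    async for response in claude_llm.generate_content_async(llm_request):
      if response.partial:
        ...  # incremental thought/text chunks arrive here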

  @cached_property
-  def _anthropic_client(self) -> AnthropicVertex:
+  def _anthropic_client(self) -> AsyncAnthropicVertex:
    if (
        "GOOGLE_CLOUD_PROJECT" not in os.environ
        or "GOOGLE_CLOUD_LOCATION" not in os.environ
@@ -305,7 +491,7 @@ def _anthropic_client(self) -> AnthropicVertex:
" Anthropic on Vertex."
)

return AnthropicVertex(
return AsyncAnthropicVertex(
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
region=os.environ["GOOGLE_CLOUD_LOCATION"],
)
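
The async client reads the same environment variables as the synchronous one;
a minimal setup sketch (the values are placeholders):

    os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"
    os.environ["GOOGLE_CLOUD_LOCATION"] = "us-east5"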
4 changes: 3 additions & 1 deletion tests/unittests/models/test_anthropic_llm.py
@@ -293,8 +293,10 @@ async def test_function_declaration_to_tool_param(

@pytest.mark.asyncio
async def test_generate_content_async(
-    claude_llm, llm_request, generate_content_response, generate_llm_response
+    llm_request, generate_content_response, generate_llm_response
):
  # Use max_tokens < 8192 to trigger the non-streaming mode.
  claude_llm = Claude(model="claude-3-5-sonnet-v2@20241022", max_tokens=4096)
  with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
    with mock.patch.object(
        anthropic_llm,