Skip to content

Commit 5628846

Browse files
google-genai-bot authored and copybara-github committed
fix: Updated some flows to run on BaseAgent instances where supported
PiperOrigin-RevId: 794143338
1 parent c52f956 commit 5628846

File tree

6 files changed

+62
-34
lines changed

6 files changed

+62
-34
lines changed

src/google/adk/agents/readonly_context.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,23 @@
1414

1515
from __future__ import annotations
1616

17+
import copy
1718
from types import MappingProxyType
1819
from typing import Any
20+
from typing import List
1921
from typing import Optional
2022
from typing import TYPE_CHECKING
23+
from typing import TypeVar
2124

2225
if TYPE_CHECKING:
2326
from google.genai import types
2427

2528
from .invocation_context import InvocationContext
2629

2730

31+
Event = TypeVar('Event')
32+
33+
2834
class ReadonlyContext:
2935

3036
def __init__(
@@ -52,3 +58,8 @@ def agent_name(self) -> str:
5258
def state(self) -> MappingProxyType[str, Any]:
5359
"""The state of the current session. READONLY field."""
5460
return MappingProxyType(self._invocation_context.session.state)
61+
62+
@property
63+
def events(self) -> List[Event]:
64+
"""Historical events from the current session."""
65+
return copy.deepcopy(self._invocation_context.session.events)

src/google/adk/auth/auth_preprocessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def run_async(
4444
from ..agents.llm_agent import LlmAgent
4545

4646
agent = invocation_context.agent
47-
if not isinstance(agent, LlmAgent):
47+
if not hasattr(agent, 'canonical_tools'):
4848
return
4949
events = invocation_context.session.events
5050
if not events:
@@ -110,7 +110,7 @@ async def run_async(
110110
event,
111111
{
112112
tool.name: tool
113-
for tool in await agent.canonical_tools(
113+
for tool in await getattr(agent, 'canonical_tools')(
114114
ReadonlyContext(invocation_context)
115115
)
116116
},

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,14 +378,16 @@ async def _preprocess_async(
378378
from ...agents.llm_agent import LlmAgent
379379

380380
agent = invocation_context.agent
381-
if not isinstance(agent, LlmAgent):
382-
return
383381

384382
# Runs processors.
385383
for processor in self.request_processors:
384+
logging.debug(f'Running processor: {type(processor).__name__}')
386385
async for event in processor.run_async(invocation_context, llm_request):
387386
yield event
388387

388+
if not isinstance(agent, LlmAgent):
389+
return
390+
389391
# Run processors for tools.
390392
for tool_union in agent.tools:
391393
tool_context = ToolContext(invocation_context)

src/google/adk/flows/llm_flows/basic.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from __future__ import annotations
1818

1919
from typing import AsyncGenerator
20-
from typing import Generator
2120

2221
from google.genai import types
2322
from typing_extensions import override
@@ -36,26 +35,7 @@ async def run_async(
3635
) -> AsyncGenerator[Event, None]:
3736
from ...agents.llm_agent import LlmAgent
3837

39-
agent = invocation_context.agent
40-
if not isinstance(agent, LlmAgent):
41-
return
42-
43-
llm_request.model = (
44-
agent.canonical_model
45-
if isinstance(agent.canonical_model, str)
46-
else agent.canonical_model.model
47-
)
48-
llm_request.config = (
49-
agent.generate_content_config.model_copy(deep=True)
50-
if agent.generate_content_config
51-
else types.GenerateContentConfig()
52-
)
53-
# Only set output_schema if no tools are specified. as of now, model don't
54-
# support output_schema and tools together. we have a workaround to support
55-
# both outoput_schema and tools at the same time. see
56-
# _output_schema_processor.py for details
57-
if agent.output_schema and not agent.tools:
58-
llm_request.set_output_schema(agent.output_schema)
38+
llm_request.config = types.GenerateContentConfig()
5939

6040
llm_request.live_connect_config.response_modalities = (
6141
invocation_context.run_config.response_modalities
@@ -81,6 +61,23 @@ async def run_async(
8161
llm_request.live_connect_config.session_resumption = (
8262
invocation_context.run_config.session_resumption
8363
)
64+
agent = invocation_context.agent
65+
if not isinstance(agent, LlmAgent):
66+
return
67+
68+
llm_request.model = (
69+
agent.canonical_model
70+
if isinstance(agent.canonical_model, str)
71+
else agent.canonical_model.model
72+
)
73+
# Only set output_schema if no tools are specified. as of now, model don't
74+
# support output_schema and tools together. we have a workaround to support
75+
# both outoput_schema and tools at the same time. see
76+
# _output_schema_processor.py for details
77+
if agent.output_schema and not agent.tools:
78+
llm_request.set_output_schema(agent.output_schema)
79+
if agent.generate_content_config:
80+
llm_request.config = agent.generate_content_config.model_copy(deep=True)
8481

8582
# TODO: handle tool append here, instead of in BaseTool.process_llm_request.
8683

src/google/adk/flows/llm_flows/contents.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,24 @@
3333
class _ContentLlmRequestProcessor(BaseLlmRequestProcessor):
3434
"""Builds the contents for the LLM request."""
3535

36+
def __init__(self, mutating: bool):
37+
self.mutating = mutating
38+
3639
@override
3740
async def run_async(
3841
self, invocation_context: InvocationContext, llm_request: LlmRequest
3942
) -> AsyncGenerator[Event, None]:
4043
from ...agents.llm_agent import LlmAgent
4144

4245
agent = invocation_context.agent
43-
if not isinstance(agent, LlmAgent):
44-
return
4546

46-
if agent.include_contents == 'default':
47+
if not isinstance(agent, LlmAgent) or agent.include_contents == 'default':
4748
# Include full conversation history
4849
llm_request.contents = _get_contents(
4950
invocation_context.branch,
5051
invocation_context.session.events,
5152
agent.name,
53+
self.mutating,
5254
)
5355
else:
5456
# Include current turn context only (no conversation history)
@@ -63,7 +65,8 @@ async def run_async(
6365
yield # This is a no-op but maintains generator structure
6466

6567

66-
request_processor = _ContentLlmRequestProcessor()
68+
request_processor = _ContentLlmRequestProcessor(mutating=True)
69+
non_mutating_request_processor = _ContentLlmRequestProcessor(mutating=False)
6770

6871

6972
def _rearrange_events_for_async_function_responses_in_history(
@@ -203,7 +206,10 @@ def _rearrange_events_for_latest_function_response(
203206

204207

205208
def _get_contents(
206-
current_branch: Optional[str], events: list[Event], agent_name: str = ''
209+
current_branch: Optional[str],
210+
events: list[Event],
211+
agent_name: str = '',
212+
mutating: bool = True,
207213
) -> list[types.Content]:
208214
"""Get the contents for the LLM request.
209215
@@ -213,6 +219,7 @@ def _get_contents(
213219
current_branch: The current branch of the agent.
214220
events: Events to process.
215221
agent_name: The name of the agent.
222+
mutating: Whether to rewrite all conversation turns as user verbalizations.
216223
217224
Returns:
218225
A list of processed contents.
@@ -240,7 +247,7 @@ def _get_contents(
240247
continue
241248
filtered_events.append(
242249
_convert_foreign_event(event)
243-
if _is_other_agent_reply(agent_name, event)
250+
if mutating and _is_other_agent_reply(agent_name, event)
244251
else event
245252
)
246253

@@ -313,7 +320,6 @@ def _convert_foreign_event(event: Event) -> Event:
313320
314321
Returns:
315322
The converted event.
316-
317323
"""
318324
if not event.content or not event.content.parts:
319325
return event

src/google/adk/tools/agent_tool.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,19 @@ class AgentTool(BaseTool):
4646
Attributes:
4747
agent: The agent to wrap.
4848
skip_summarization: Whether to skip summarization of the agent output.
49+
include_conversational_context: Whether to pass conversation history through
50+
to the sub-agent.
4951
"""
5052

51-
def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
53+
def __init__(
54+
self,
55+
agent: BaseAgent,
56+
skip_summarization: bool = False,
57+
include_conversational_context: bool = False,
58+
):
5259
self.agent = agent
53-
self.skip_summarization: bool = skip_summarization
60+
self.skip_summarization = skip_summarization
61+
self.include_conversational_context = include_conversational_context
5462

5563
super().__init__(name=agent.name, description=agent.description)
5664

@@ -140,6 +148,10 @@ async def run_async(
140148
state=tool_context.state.to_dict(),
141149
)
142150

151+
if self.include_conversational_context:
152+
for event in tool_context.events[:-1]:
153+
await runner.session_service.append_event(session, event)
154+
143155
last_event = None
144156
async for event in runner.run_async(
145157
user_id=session.user_id, session_id=session.id, new_message=content

0 commit comments

Comments (0)