Skip to content

Commit a30c63c

Browse files
jawoszekcopybara-github
authored andcommitted
fix: aclose all async generators to fix OTel tracing context
See #1670 (comment) PiperOrigin-RevId: 794659547
1 parent c5af44c commit a30c63c

23 files changed

+733
-512
lines changed

contributing/samples/telemetry/main.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,19 @@ async def run_prompt(session: Session, new_message: str):
4646
role='user', parts=[types.Part.from_text(text=new_message)]
4747
)
4848
print('** User says:', content.model_dump(exclude_none=True))
49-
async for event in runner.run_async(
49+
# TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is
50+
# no longer supported.
51+
agen = runner.run_async(
5052
user_id=user_id_1,
5153
session_id=session.id,
5254
new_message=content,
53-
):
54-
if event.content.parts and event.content.parts[0].text:
55-
print(f'** {event.author}: {event.content.parts[0].text}')
55+
)
56+
try:
57+
async for event in agen:
58+
if event.content.parts and event.content.parts[0].text:
59+
print(f'** {event.author}: {event.content.parts[0].text}')
60+
finally:
61+
await agen.aclose()
5662

5763
async def run_prompt_bytes(session: Session, new_message: str):
5864
content = types.Content(
@@ -64,14 +70,20 @@ async def run_prompt_bytes(session: Session, new_message: str):
6470
],
6571
)
6672
print('** User says:', content.model_dump(exclude_none=True))
67-
async for event in runner.run_async(
73+
# TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is
74+
# no longer supported.
75+
agen = runner.run_async(
6876
user_id=user_id_1,
6977
session_id=session.id,
7078
new_message=content,
7179
run_config=RunConfig(save_input_blobs_as_artifacts=True),
72-
):
73-
if event.content.parts and event.content.parts[0].text:
74-
print(f'** {event.author}: {event.content.parts[0].text}')
80+
)
81+
try:
82+
async for event in agen:
83+
if event.content.parts and event.content.parts[0].text:
84+
print(f'** {event.author}: {event.content.parts[0].text}')
85+
finally:
86+
await agen.aclose()
7587

7688
start_time = time.time()
7789
print('Start time:', start_time)

src/google/adk/a2a/executor/a2a_agent_executor.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from typing import Optional
2525
import uuid
2626

27+
from ...utils.context_utils import Aclosing
28+
2729
try:
2830
from a2a.server.agent_execution import AgentExecutor
2931
from a2a.server.agent_execution.context import RequestContext
@@ -212,12 +214,13 @@ async def _handle_request(
212214
)
213215

214216
task_result_aggregator = TaskResultAggregator()
215-
async for adk_event in runner.run_async(**run_args):
216-
for a2a_event in convert_event_to_a2a_events(
217-
adk_event, invocation_context, context.task_id, context.context_id
218-
):
219-
task_result_aggregator.process_event(a2a_event)
220-
await event_queue.enqueue_event(a2a_event)
217+
async with Aclosing(runner.run_async(**run_args)) as agen:
218+
async for adk_event in agen:
219+
for a2a_event in convert_event_to_a2a_events(
220+
adk_event, invocation_context, context.task_id, context.context_id
221+
):
222+
task_result_aggregator.process_event(a2a_event)
223+
await event_queue.enqueue_event(a2a_event)
221224

222225
# publish the task result event - this is final
223226
if (

src/google/adk/agents/base_agent.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from typing_extensions import TypeAlias
4040

4141
from ..events.event import Event
42+
from ..utils.context_utils import Aclosing
4243
from ..utils.feature_decorator import experimental
4344
from .base_agent_config import BaseAgentConfig
4445
from .callback_context import CallbackContext
@@ -212,21 +213,27 @@ async def run_async(
212213
Event: the events generated by the agent.
213214
"""
214215

215-
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
216-
ctx = self._create_invocation_context(parent_context)
216+
async def _run_with_trace() -> AsyncGenerator[Event, None]:
217+
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
218+
ctx = self._create_invocation_context(parent_context)
217219

218-
if event := await self.__handle_before_agent_callback(ctx):
219-
yield event
220-
if ctx.end_invocation:
221-
return
220+
if event := await self.__handle_before_agent_callback(ctx):
221+
yield event
222+
if ctx.end_invocation:
223+
return
222224

223-
async for event in self._run_async_impl(ctx):
224-
yield event
225+
async with Aclosing(self._run_async_impl(ctx)) as agen:
226+
async for event in agen:
227+
yield event
225228

226-
if ctx.end_invocation:
227-
return
229+
if ctx.end_invocation:
230+
return
228231

229-
if event := await self.__handle_after_agent_callback(ctx):
232+
if event := await self.__handle_after_agent_callback(ctx):
233+
yield event
234+
235+
async with Aclosing(_run_with_trace()) as agen:
236+
async for event in agen:
230237
yield event
231238

232239
@final
@@ -243,18 +250,25 @@ async def run_live(
243250
Yields:
244251
Event: the events generated by the agent.
245252
"""
246-
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
247-
ctx = self._create_invocation_context(parent_context)
248253

249-
if event := await self.__handle_before_agent_callback(ctx):
250-
yield event
251-
if ctx.end_invocation:
252-
return
254+
async def _run_with_trace() -> AsyncGenerator[Event, None]:
255+
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
256+
ctx = self._create_invocation_context(parent_context)
253257

254-
async for event in self._run_live_impl(ctx):
255-
yield event
258+
if event := await self.__handle_before_agent_callback(ctx):
259+
yield event
260+
if ctx.end_invocation:
261+
return
262+
263+
async with Aclosing(self._run_live_impl(ctx)) as agen:
264+
async for event in agen:
265+
yield event
266+
267+
if event := await self.__handle_after_agent_callback(ctx):
268+
yield event
256269

257-
if event := await self.__handle_after_agent_callback(ctx):
270+
async with Aclosing(_run_with_trace()) as agen:
271+
async for event in agen:
258272
yield event
259273

260274
async def _run_async_impl(

src/google/adk/agents/llm_agent.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from ..tools.function_tool import FunctionTool
5252
from ..tools.tool_configs import ToolConfig
5353
from ..tools.tool_context import ToolContext
54+
from ..utils.context_utils import Aclosing
5455
from ..utils.feature_decorator import experimental
5556
from .base_agent import BaseAgent
5657
from .base_agent_config import BaseAgentConfig
@@ -283,19 +284,21 @@ class LlmAgent(BaseAgent):
283284
async def _run_async_impl(
284285
self, ctx: InvocationContext
285286
) -> AsyncGenerator[Event, None]:
286-
async for event in self._llm_flow.run_async(ctx):
287-
self.__maybe_save_output_to_state(event)
288-
yield event
287+
async with Aclosing(self._llm_flow.run_async(ctx)) as agen:
288+
async for event in agen:
289+
self.__maybe_save_output_to_state(event)
290+
yield event
289291

290292
@override
291293
async def _run_live_impl(
292294
self, ctx: InvocationContext
293295
) -> AsyncGenerator[Event, None]:
294-
async for event in self._llm_flow.run_live(ctx):
295-
self.__maybe_save_output_to_state(event)
296-
yield event
297-
if ctx.end_invocation:
298-
return
296+
async with Aclosing(self._llm_flow.run_live(ctx)) as agen:
297+
async for event in agen:
298+
self.__maybe_save_output_to_state(event)
299+
yield event
300+
if ctx.end_invocation:
301+
return
299302

300303
@property
301304
def canonical_model(self) -> BaseLlm:

src/google/adk/agents/loop_agent.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from ..agents.invocation_context import InvocationContext
2929
from ..events.event import Event
30+
from ..utils.context_utils import Aclosing
3031
from ..utils.feature_decorator import experimental
3132
from .base_agent import BaseAgent
3233
from .base_agent_config import BaseAgentConfig
@@ -58,10 +59,11 @@ async def _run_async_impl(
5859
while not self.max_iterations or times_looped < self.max_iterations:
5960
for sub_agent in self.sub_agents:
6061
should_exit = False
61-
async for event in sub_agent.run_async(ctx):
62-
yield event
63-
if event.actions.escalate:
64-
should_exit = True
62+
async with Aclosing(sub_agent.run_async(ctx)) as agen:
63+
async for event in agen:
64+
yield event
65+
if event.actions.escalate:
66+
should_exit = True
6567

6668
if should_exit:
6769
return

src/google/adk/agents/parallel_agent.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing_extensions import override
2727

2828
from ..events.event import Event
29+
from ..utils.context_utils import Aclosing
2930
from .base_agent import BaseAgent
3031
from .base_agent_config import BaseAgentConfig
3132
from .invocation_context import InvocationContext
@@ -111,8 +112,10 @@ async def _run_async_impl(
111112
)
112113
for sub_agent in self.sub_agents
113114
]
114-
async for event in _merge_agent_run(agent_runs):
115-
yield event
115+
116+
async with Aclosing(_merge_agent_run(agent_runs)) as agen:
117+
async for event in agen:
118+
yield event
116119

117120
@override
118121
async def _run_live_impl(

src/google/adk/agents/sequential_agent.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing_extensions import override
2323

2424
from ..events.event import Event
25+
from ..utils.context_utils import Aclosing
2526
from .base_agent import BaseAgent
2627
from .base_agent import BaseAgentConfig
2728
from .invocation_context import InvocationContext
@@ -40,8 +41,9 @@ async def _run_async_impl(
4041
self, ctx: InvocationContext
4142
) -> AsyncGenerator[Event, None]:
4243
for sub_agent in self.sub_agents:
43-
async for event in sub_agent.run_async(ctx):
44-
yield event
44+
async with Aclosing(sub_agent.run_async(ctx)) as agen:
45+
async for event in agen:
46+
yield event
4547

4648
@override
4749
async def _run_live_impl(
@@ -78,5 +80,6 @@ def task_completed():
7880
do not generate any text other than the function call."""
7981

8082
for sub_agent in self.sub_agents:
81-
async for event in sub_agent.run_live(ctx):
82-
yield event
83+
async with Aclosing(sub_agent.run_live(ctx)) as agen:
84+
async for event in agen:
85+
yield event

src/google/adk/cli/adk_web_server.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from ..runners import Runner
7474
from ..sessions.base_session_service import BaseSessionService
7575
from ..sessions.session import Session
76+
from ..utils.context_utils import Aclosing
7677
from .cli_eval import EVAL_SESSION_ID_PREFIX
7778
from .cli_eval import EvalStatus
7879
from .utils import cleanup
@@ -828,14 +829,16 @@ async def agent_run(req: AgentRunRequest) -> list[Event]:
828829
if not session:
829830
raise HTTPException(status_code=404, detail="Session not found")
830831
runner = await self.get_runner_async(req.app_name)
831-
events = [
832-
event
833-
async for event in runner.run_async(
832+
833+
events = []
834+
async with Aclosing(
835+
runner.run_async(
834836
user_id=req.user_id,
835837
session_id=req.session_id,
836838
new_message=req.new_message,
837839
)
838-
]
840+
) as agen:
841+
events = [event async for event in agen]
839842
logger.info("Generated %s events in agent run", len(events))
840843
logger.debug("Events generated: %s", events)
841844
return events
@@ -856,19 +859,24 @@ async def event_generator():
856859
StreamingMode.SSE if req.streaming else StreamingMode.NONE
857860
)
858861
runner = await self.get_runner_async(req.app_name)
859-
async for event in runner.run_async(
860-
user_id=req.user_id,
861-
session_id=req.session_id,
862-
new_message=req.new_message,
863-
state_delta=req.state_delta,
864-
run_config=RunConfig(streaming_mode=stream_mode),
865-
):
866-
# Format as SSE data
867-
sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
868-
logger.debug(
869-
"Generated event in agent run streaming: %s", sse_event
870-
)
871-
yield f"data: {sse_event}\n\n"
862+
async with Aclosing(
863+
runner.run_async(
864+
user_id=req.user_id,
865+
session_id=req.session_id,
866+
new_message=req.new_message,
867+
state_delta=req.state_delta,
868+
run_config=RunConfig(streaming_mode=stream_mode),
869+
)
870+
) as agen:
871+
async for event in agen:
872+
# Format as SSE data
873+
sse_event = event.model_dump_json(
874+
exclude_none=True, by_alias=True
875+
)
876+
logger.debug(
877+
"Generated event in agent run streaming: %s", sse_event
878+
)
879+
yield f"data: {sse_event}\n\n"
872880
except Exception as e:
873881
logger.exception("Error in event_generator: %s", e)
874882
# You might want to yield an error event here
@@ -954,12 +962,15 @@ async def agent_live_run(
954962

955963
async def forward_events():
956964
runner = await self.get_runner_async(app_name)
957-
async for event in runner.run_live(
958-
session=session, live_request_queue=live_request_queue
959-
):
960-
await websocket.send_text(
961-
event.model_dump_json(exclude_none=True, by_alias=True)
962-
)
965+
async with Aclosing(
966+
runner.run_live(
967+
session=session, live_request_queue=live_request_queue
968+
)
969+
) as agen:
970+
async for event in agen:
971+
await websocket.send_text(
972+
event.model_dump_json(exclude_none=True, by_alias=True)
973+
)
963974

964975
async def process_messages():
965976
try:

0 commit comments

Comments
 (0)