Skip to content

Commit 0671da7

Browse files
committed
Merge branch 'jscpd-fixes' of https://github.com/google/a2a-python into jscpd-fixes
2 parents 6a15519 + 5b608ce commit 0671da7

File tree

1 file changed

+72
-23
lines changed

1 file changed

+72
-23
lines changed

examples/google_adk/calendar_agent/adk_agent_executor.py

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from google.adk import Runner
1010
from google.adk.auth import AuthConfig
1111
from google.adk.events import Event
12+
from google.adk.session_service import (
13+
Session,
14+
)
1215
from google.genai import types
1316

1417
from a2a.server.agent_execution import AgentExecutor
@@ -45,6 +48,7 @@ class ADKAgentExecutor(AgentExecutor):
4548
"""An AgentExecutor that runs an ADK-based Agent."""
4649

4750
_awaiting_auth: dict[str, asyncio.Future]
51+
_running_sessions: dict[str, Session]
4852

4953
def __init__(self, runner: Runner, card: AgentCard):
5054
self.runner = runner
@@ -98,12 +102,17 @@ async def _process_request(
98102
# is received.
99103
break
100104
if event.is_final_response():
101-
parts = convert_genai_parts_to_a2a(event.content.parts)
102-
logger.debug('Yielding final response: %s', parts)
103-
task_updater.add_artifact(parts)
104-
task_updater.complete()
105+
if event.content and event.content.parts:
106+
parts = convert_genai_parts_to_a2a(event.content.parts)
107+
logger.debug('Yielding final response: %s', parts)
108+
task_updater.add_artifact(parts)
109+
task_updater.complete()
105110
break
106-
if not event.get_function_calls():
111+
if (
112+
not event.get_function_calls()
113+
and event.content
114+
and event.content.parts
115+
):
107116
logger.debug('Yielding update response')
108117
task_updater.update_status(
109118
TaskState.working,
@@ -208,11 +217,29 @@ async def execute(
208217
event_queue: EventQueue,
209218
):
210219
# Run the agent until either complete or the task is suspended.
220+
assert context.task_id is not None, (
221+
'Task ID must be present for execution'
222+
)
223+
assert context.context_id is not None, (
224+
'Context ID must be present for execution'
225+
)
226+
211227
updater = TaskUpdater(event_queue, context.task_id, context.context_id)
212228
# Immediately notify that the task is submitted.
213229
if not context.current_task:
214230
updater.submit()
215231
updater.start_work()
232+
233+
if context.message is None:
234+
logger.warning('Execute called with no message in context')
235+
updater.update_status(
236+
TaskState.failed,
237+
message=new_agent_text_message(
238+
'No message provided to the agent.'
239+
),
240+
)
241+
return
242+
216243
await self._process_request(
217244
types.UserContent(
218245
parts=convert_a2a_parts_to_genai(context.message.parts),
@@ -229,12 +256,17 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue):
229256
async def on_auth_callback(self, state: str, uri: str):
230257
self._awaiting_auth[state].set_result(uri)
231258

232-
def _upsert_session(self, session_id: str):
233-
return self.runner.session_service.get_session(
234-
app_name=self.runner.app_name, user_id='self', session_id=session_id
235-
) or self.runner.session_service.create_session(
259+
def _upsert_session(self, session_id: str) -> Session:
260+
session = self.runner.session_service.get_session(
236261
app_name=self.runner.app_name, user_id='self', session_id=session_id
237262
)
263+
if session is None:
264+
session = self.runner.session_service.create_session(
265+
app_name=self.runner.app_name,
266+
user_id='self',
267+
session_id=session_id,
268+
)
269+
return session
238270

239271

240272
def convert_a2a_parts_to_genai(parts: list[Part]) -> list[types.Part]:
@@ -244,32 +276,43 @@ def convert_a2a_parts_to_genai(parts: list[Part]) -> list[types.Part]:
244276

245277
def convert_a2a_part_to_genai(part: Part) -> types.Part:
246278
"""Convert a single A2A Part type into a Google Gen AI Part type."""
247-
part = part.root
248-
if isinstance(part, TextPart):
249-
return types.Part(text=part.text)
250-
if isinstance(part, FilePart):
251-
if isinstance(part.file, FileWithUri):
279+
if isinstance(part.root, TextPart):
280+
return types.Part(text=part.root.text)
281+
if isinstance(part.root, FilePart):
282+
file_data = part.root.file
283+
if isinstance(file_data, FileWithUri):
252284
return types.Part(
253285
file_data=types.FileData(
254-
file_uri=part.file.uri, mime_type=part.file.mime_type
286+
file_uri=file_data.uri,
287+
mime_type=file_data.mime_type,
255288
)
256289
)
257-
if isinstance(part.file, FileWithBytes):
290+
if isinstance(file_data, FileWithBytes):
258291
return types.Part(
259292
inline_data=types.Blob(
260-
data=part.file.bytes, mime_type=part.file.mime_type
293+
data=file_data.bytes,
294+
mime_type=file_data.mime_type,
261295
)
262296
)
263-
raise ValueError(f'Unsupported file type: {type(part.file)}')
264-
raise ValueError(f'Unsupported part type: {type(part)}')
297+
raise ValueError(f'Unsupported file type: {type(file_data)}')
298+
raise ValueError(f'Unsupported part root type: {type(part.root)}')
265299

266300

267301
def convert_genai_parts_to_a2a(parts: list[types.Part]) -> list[Part]:
268302
"""Convert a list of Google Gen AI Part types into a list of A2A Part types."""
303+
if not parts:
304+
return []
269305
return [
270306
convert_genai_part_to_a2a(part)
271307
for part in parts
272-
if (part.text or part.file_data or part.inline_data)
308+
if part
309+
and (
310+
part.text
311+
or part.file_data
312+
or part.inline_data
313+
or part.function_call
314+
or part.function_response
315+
)
273316
]
274317

275318

@@ -296,7 +339,9 @@ def convert_genai_part_to_a2a(part: types.Part) -> Part:
296339
raise ValueError(f'Unsupported part type: {part}')
297340

298341

299-
def get_auth_request_function_call(event: Event) -> types.FunctionCall:
342+
def get_auth_request_function_call(
343+
event: Event,
344+
) -> types.FunctionCall | None:
300345
"""Get the special auth request function call from the event."""
301346
if not (event.content and event.content.parts):
302347
return None
@@ -316,8 +361,12 @@ def get_auth_config(
316361
auth_request_function_call: types.FunctionCall,
317362
) -> AuthConfig:
318363
"""Extracts the AuthConfig object from the arguments of the auth request function call."""
319-
if not auth_request_function_call.args or not (
320-
auth_config := auth_request_function_call.args.get('auth_config')
364+
if (
365+
not auth_request_function_call.args
366+
or not isinstance(auth_request_function_call.args, dict)
367+
or not (
368+
auth_config := auth_request_function_call.args.get('auth_config')
369+
)
321370
):
322371
raise ValueError(
323372
f'Cannot get auth config from function call: {auth_request_function_call}'

0 commit comments

Comments
 (0)