Skip to content

Commit f456b70

Browse files
Add ADK model handler (#37917)
* Add ADK model handler * Small cleanup * CHANGES * Fix up some tests * Linting * lint * remove disclaimer, we don't do previews like this * Fix gemini comments * Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update sdks/python/apache_beam/ml/inference/agent_development_kit.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update sdks/python/apache_beam/ml/inference/agent_development_kit.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * tests + lint * Pipe operator --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 1fb3088 commit f456b70

File tree

4 files changed

+632
-0
lines changed

4 files changed

+632
-0
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969

7070
## New Features / Improvements
7171

72+
* Added `ADKAgentModelHandler` for running Google Agent Development Kit (ADK) agents (Python) ([#37917](https://github.com/apache/beam/issues/37917)).
7273
* (Python) Added exception chaining to preserve error context in CloudSQLEnrichmentHandler, processes utilities, and core transforms ([#37422](https://github.com/apache/beam/issues/37422)).
7374
* (Python) Added a pipeline option `--experiments=pip_no_build_isolation` to disable build isolation when installing dependencies in the runtime environment ([#37331](https://github.com/apache/beam/issues/37331)).
7475
* (Go) Added OrderedListState support to the Go SDK stateful DoFn API ([#37629](https://github.com/apache/beam/issues/37629)).
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""ModelHandler for running agents built with the Google Agent Development Kit.
19+
20+
This module provides :class:`ADKAgentModelHandler`, a Beam
21+
:class:`~apache_beam.ml.inference.base.ModelHandler` that wraps an ADK
22+
:class:`google.adk.agents.llm_agent.LlmAgent` so it can be used with the
23+
:class:`~apache_beam.ml.inference.base.RunInference` transform.
24+
25+
Typical usage::
26+
27+
import apache_beam as beam
28+
from apache_beam.ml.inference.base import RunInference
29+
from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler
30+
from google.adk.agents import LlmAgent
31+
32+
agent = LlmAgent(
33+
name="my_agent",
34+
model="gemini-2.0-flash",
35+
instruction="You are a helpful assistant.",
36+
)
37+
38+
with beam.Pipeline() as p:
39+
results = (
40+
p
41+
| beam.Create(["What is the capital of France?"])
42+
| RunInference(ADKAgentModelHandler(agent=agent))
43+
)
44+
45+
If your agent contains state that is not picklable (e.g. tool closures that
46+
capture unpicklable objects), pass a zero-arg factory callable instead::
47+
48+
handler = ADKAgentModelHandler(agent=lambda: LlmAgent(...))
49+
50+
"""
51+
52+
import asyncio
53+
import logging
54+
import uuid
55+
from collections.abc import Callable
56+
from collections.abc import Iterable
57+
from collections.abc import Sequence
58+
from typing import Any
59+
from typing import Optional
60+
61+
from apache_beam.ml.inference.base import ModelHandler
62+
from apache_beam.ml.inference.base import PredictionResult
63+
64+
try:
65+
from google.adk import sessions
66+
from google.adk.agents import Agent
67+
from google.adk.runners import Runner
68+
from google.adk.sessions import BaseSessionService
69+
from google.adk.sessions import InMemorySessionService
70+
from google.genai.types import Content as genai_Content
71+
from google.genai.types import Part as genai_Part
72+
ADK_AVAILABLE = True
73+
except ImportError:
74+
ADK_AVAILABLE = False
75+
genai_Content = Any # type: ignore[assignment, misc]
76+
genai_Part = Any # type: ignore[assignment, misc]
77+
Agent = None
78+
79+
LOGGER = logging.getLogger("ADKAgentModelHandler")
80+
81+
# Type alias for an agent or factory that produces one
82+
_AgentOrFactory = Agent | Callable[[], Agent]
83+
84+
85+
class ADKAgentModelHandler(ModelHandler[str | genai_Content,
86+
PredictionResult,
87+
"Runner"]):
88+
"""ModelHandler for running ADK agents with the Beam RunInference transform.
89+
90+
Accepts either a fully constructed :class:`google.adk.agents.Agent` or a
91+
zero-arg factory callable that produces one. The factory form is useful when
92+
the agent contains state that is not picklable and therefore cannot be
93+
serialized alongside the pipeline graph.
94+
95+
Each call to :meth:`run_inference` invokes the agent once per element in the
96+
batch. By default every invocation uses a fresh, isolated session (stateless).
97+
Stateful multi-turn conversations can be achieved by passing a ``session_id``
98+
key inside ``inference_args``; elements sharing the same ``session_id`` will
99+
continue the same conversation history.
100+
101+
Args:
102+
agent: A pre-constructed :class:`~google.adk.agents.Agent` instance, or a
103+
zero-arg callable that returns one. The callable form defers agent
104+
construction to worker ``load_model`` time, which is useful when the
105+
agent cannot be serialized.
106+
app_name: The ADK application name used to namespace sessions. Defaults to
107+
``"beam_inference"``.
108+
session_service_factory: Optional zero-arg callable returning a
109+
:class:`~google.adk.sessions.BaseSessionService`. When ``None``, an
110+
:class:`~google.adk.sessions.InMemorySessionService` is created
111+
automatically.
112+
min_batch_size: Optional minimum batch size.
113+
max_batch_size: Optional maximum batch size.
114+
max_batch_duration_secs: Optional maximum time to buffer a batch before
115+
emitting; used in streaming contexts.
116+
max_batch_weight: Optional maximum total weight of a batch.
117+
element_size_fn: Optional function that returns the size (weight) of an
118+
element.
119+
"""
120+
def __init__(
121+
self,
122+
agent: _AgentOrFactory,
123+
app_name: str = "beam_inference",
124+
session_service_factory: Optional[Callable[[],
125+
"BaseSessionService"]] = None,
126+
*,
127+
min_batch_size: Optional[int] = None,
128+
max_batch_size: Optional[int] = None,
129+
max_batch_duration_secs: Optional[int] = None,
130+
max_batch_weight: Optional[int] = None,
131+
element_size_fn: Optional[Callable[[Any], int]] = None,
132+
**kwargs):
133+
if not ADK_AVAILABLE:
134+
raise ImportError(
135+
"google-adk is required to use ADKAgentModelHandler. "
136+
"Install it with: pip install google-adk")
137+
138+
if agent is None:
139+
raise ValueError("'agent' must be an Agent instance or a callable.")
140+
141+
self._agent_or_factory = agent
142+
self._app_name = app_name
143+
self._session_service_factory = session_service_factory
144+
145+
super().__init__(
146+
min_batch_size=min_batch_size,
147+
max_batch_size=max_batch_size,
148+
max_batch_duration_secs=max_batch_duration_secs,
149+
max_batch_weight=max_batch_weight,
150+
element_size_fn=element_size_fn,
151+
**kwargs)
152+
153+
def load_model(self) -> "Runner":
154+
"""Instantiates the ADK Runner on the worker.
155+
156+
Resolves the agent (calling the factory if a callable was provided), then
157+
creates a :class:`~google.adk.runners.Runner` backed by the configured
158+
session service.
159+
160+
Returns:
161+
A fully initialised :class:`~google.adk.runners.Runner`.
162+
"""
163+
if callable(self._agent_or_factory) and not isinstance(
164+
self._agent_or_factory, Agent):
165+
agent = self._agent_or_factory()
166+
else:
167+
agent = self._agent_or_factory
168+
169+
if self._session_service_factory is not None:
170+
session_service = self._session_service_factory()
171+
else:
172+
session_service = InMemorySessionService()
173+
174+
runner = Runner(
175+
agent=agent,
176+
app_name=self._app_name,
177+
session_service=session_service,
178+
)
179+
LOGGER.info(
180+
"Loaded ADK Runner for agent '%s' (app_name='%s')",
181+
agent.name,
182+
self._app_name,
183+
)
184+
return runner
185+
186+
def run_inference(
187+
self,
188+
batch: Sequence[str | genai_Content],
189+
model: "Runner",
190+
inference_args: Optional[dict[str, Any]] = None,
191+
) -> Iterable[PredictionResult]:
192+
"""Runs the ADK agent on each element in the batch.
193+
194+
Each element is sent to the agent as a new user turn. The final response
195+
text from the agent is returned as the ``inference`` field of a
196+
:class:`~apache_beam.ml.inference.base.PredictionResult`.
197+
198+
Args:
199+
batch: A sequence of inputs, each of which is either a ``str`` (the user
200+
message text) or a :class:`google.genai.types.Content` object (for
201+
richer multi-part messages).
202+
model: The :class:`~google.adk.runners.Runner` returned by
203+
:meth:`load_model`.
204+
inference_args: Optional dict of extra arguments. Supported keys:
205+
206+
- ``"session_id"`` (:class:`str`): If supplied, all elements in this
207+
batch share this session ID, enabling stateful multi-turn
208+
conversations. If omitted, each element receives a unique auto-
209+
generated session ID.
210+
- ``"user_id"`` (:class:`str`): The user identifier to pass to the
211+
runner. Defaults to ``"beam_user"``.
212+
213+
Returns:
214+
An iterable of :class:`~apache_beam.ml.inference.base.PredictionResult`,
215+
one per input element.
216+
"""
217+
if inference_args is None:
218+
inference_args = {}
219+
220+
user_id: str = inference_args.get("user_id", "beam_user")
221+
agent_invocations = []
222+
elements_with_sessions = []
223+
224+
for element in batch:
225+
session_id: str = inference_args.get("session_id", str(uuid.uuid4()))
226+
227+
# Ensure a session exists for this invocation
228+
try:
229+
model.session_service.create_session(
230+
app_name=self._app_name,
231+
user_id=user_id,
232+
session_id=session_id,
233+
)
234+
except sessions.SessionExistsError:
235+
# It's okay if the session already exists for shared session IDs.
236+
pass
237+
238+
# Wrap plain strings in a Content object
239+
if isinstance(element, str):
240+
message = genai_Content(role="user", parts=[genai_Part(text=element)])
241+
else:
242+
# Assume the caller has already constructed a types.Content object
243+
message = element
244+
245+
agent_invocations.append(
246+
self._invoke_agent(model, user_id, session_id, message))
247+
elements_with_sessions.append(element)
248+
249+
# Run all agent invocations concurrently
250+
async def _run_concurrently():
251+
return await asyncio.gather(*agent_invocations)
252+
253+
response_texts = asyncio.run(_run_concurrently())
254+
255+
results = []
256+
for i, element in enumerate(elements_with_sessions):
257+
results.append(
258+
PredictionResult(
259+
example=element,
260+
inference=response_texts[i],
261+
model_id=model.agent.name,
262+
))
263+
264+
return results
265+
266+
@staticmethod
267+
async def _invoke_agent(
268+
runner: "Runner",
269+
user_id: str,
270+
session_id: str,
271+
message: genai_Content,
272+
) -> Optional[str]:
273+
"""Drives the ADK event loop and returns the final response text.
274+
275+
Args:
276+
runner: The ADK Runner to invoke.
277+
user_id: The user ID for this invocation.
278+
session_id: The session ID for this invocation.
279+
message: The :class:`google.genai.types.Content` to send.
280+
281+
Returns:
282+
The text of the agent's final response, or ``None`` if the agent
283+
produced no final text response.
284+
"""
285+
async for event in runner.run_async(
286+
user_id=user_id,
287+
session_id=session_id,
288+
new_message=message,
289+
):
290+
if event.is_final_response():
291+
if event.content:
292+
return event.content.text
293+
return None
294+
295+
def get_metrics_namespace(self) -> str:
296+
return "ADKAgentModelHandler"

0 commit comments

Comments
 (0)