|
| 1 | +# |
| 2 | +# Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +# contributor license agreements. See the NOTICE file distributed with |
| 4 | +# this work for additional information regarding copyright ownership. |
| 5 | +# The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | +# (the "License"); you may not use this file except in compliance with |
| 7 | +# the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +"""ModelHandler for running agents built with the Google Agent Development Kit. |
| 19 | +
|
| 20 | +This module provides :class:`ADKAgentModelHandler`, a Beam |
| 21 | +:class:`~apache_beam.ml.inference.base.ModelHandler` that wraps an ADK |
| 22 | +:class:`google.adk.agents.llm_agent.LlmAgent` so it can be used with the |
| 23 | +:class:`~apache_beam.ml.inference.base.RunInference` transform. |
| 24 | +
|
| 25 | +Typical usage:: |
| 26 | +
|
| 27 | + import apache_beam as beam |
| 28 | + from apache_beam.ml.inference.base import RunInference |
| 29 | + from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler |
| 30 | + from google.adk.agents import LlmAgent |
| 31 | +
|
| 32 | + agent = LlmAgent( |
| 33 | + name="my_agent", |
| 34 | + model="gemini-2.0-flash", |
| 35 | + instruction="You are a helpful assistant.", |
| 36 | + ) |
| 37 | +
|
| 38 | + with beam.Pipeline() as p: |
| 39 | + results = ( |
| 40 | + p |
| 41 | + | beam.Create(["What is the capital of France?"]) |
| 42 | + | RunInference(ADKAgentModelHandler(agent=agent)) |
| 43 | + ) |
| 44 | +
|
| 45 | +If your agent contains state that is not picklable (e.g. tool closures that |
| 46 | +capture unpicklable objects), pass a zero-arg factory callable instead:: |
| 47 | +
|
| 48 | + handler = ADKAgentModelHandler(agent=lambda: LlmAgent(...)) |
| 49 | +
|
| 50 | +""" |
| 51 | + |
| 52 | +import asyncio |
| 53 | +import logging |
| 54 | +import uuid |
| 55 | +from collections.abc import Callable |
| 56 | +from collections.abc import Iterable |
| 57 | +from collections.abc import Sequence |
| 58 | +from typing import Any |
| 59 | +from typing import Optional |
| 60 | + |
| 61 | +from apache_beam.ml.inference.base import ModelHandler |
| 62 | +from apache_beam.ml.inference.base import PredictionResult |
| 63 | + |
| 64 | +try: |
| 65 | + from google.adk import sessions |
| 66 | + from google.adk.agents import Agent |
| 67 | + from google.adk.runners import Runner |
| 68 | + from google.adk.sessions import BaseSessionService |
| 69 | + from google.adk.sessions import InMemorySessionService |
| 70 | + from google.genai.types import Content as genai_Content |
| 71 | + from google.genai.types import Part as genai_Part |
| 72 | + ADK_AVAILABLE = True |
| 73 | +except ImportError: |
| 74 | + ADK_AVAILABLE = False |
| 75 | + genai_Content = Any # type: ignore[assignment, misc] |
| 76 | + genai_Part = Any # type: ignore[assignment, misc] |
| 77 | + Agent = None |
| 78 | + |
| 79 | +LOGGER = logging.getLogger("ADKAgentModelHandler") |
| 80 | + |
| 81 | +# Type alias for an agent or factory that produces one |
| 82 | +_AgentOrFactory = Agent | Callable[[], Agent] |
| 83 | + |
| 84 | + |
| 85 | +class ADKAgentModelHandler(ModelHandler[str | genai_Content, |
| 86 | + PredictionResult, |
| 87 | + "Runner"]): |
| 88 | + """ModelHandler for running ADK agents with the Beam RunInference transform. |
| 89 | +
|
| 90 | + Accepts either a fully constructed :class:`google.adk.agents.Agent` or a |
| 91 | + zero-arg factory callable that produces one. The factory form is useful when |
| 92 | + the agent contains state that is not picklable and therefore cannot be |
| 93 | + serialized alongside the pipeline graph. |
| 94 | +
|
| 95 | + Each call to :meth:`run_inference` invokes the agent once per element in the |
| 96 | + batch. By default every invocation uses a fresh, isolated session (stateless). |
| 97 | + Stateful multi-turn conversations can be achieved by passing a ``session_id`` |
| 98 | + key inside ``inference_args``; elements sharing the same ``session_id`` will |
| 99 | + continue the same conversation history. |
| 100 | +
|
| 101 | + Args: |
| 102 | + agent: A pre-constructed :class:`~google.adk.agents.Agent` instance, or a |
| 103 | + zero-arg callable that returns one. The callable form defers agent |
| 104 | + construction to worker ``load_model`` time, which is useful when the |
| 105 | + agent cannot be serialized. |
| 106 | + app_name: The ADK application name used to namespace sessions. Defaults to |
| 107 | + ``"beam_inference"``. |
| 108 | + session_service_factory: Optional zero-arg callable returning a |
| 109 | + :class:`~google.adk.sessions.BaseSessionService`. When ``None``, an |
| 110 | + :class:`~google.adk.sessions.InMemorySessionService` is created |
| 111 | + automatically. |
| 112 | + min_batch_size: Optional minimum batch size. |
| 113 | + max_batch_size: Optional maximum batch size. |
| 114 | + max_batch_duration_secs: Optional maximum time to buffer a batch before |
| 115 | + emitting; used in streaming contexts. |
| 116 | + max_batch_weight: Optional maximum total weight of a batch. |
| 117 | + element_size_fn: Optional function that returns the size (weight) of an |
| 118 | + element. |
| 119 | + """ |
| 120 | + def __init__( |
| 121 | + self, |
| 122 | + agent: _AgentOrFactory, |
| 123 | + app_name: str = "beam_inference", |
| 124 | + session_service_factory: Optional[Callable[[], |
| 125 | + "BaseSessionService"]] = None, |
| 126 | + *, |
| 127 | + min_batch_size: Optional[int] = None, |
| 128 | + max_batch_size: Optional[int] = None, |
| 129 | + max_batch_duration_secs: Optional[int] = None, |
| 130 | + max_batch_weight: Optional[int] = None, |
| 131 | + element_size_fn: Optional[Callable[[Any], int]] = None, |
| 132 | + **kwargs): |
| 133 | + if not ADK_AVAILABLE: |
| 134 | + raise ImportError( |
| 135 | + "google-adk is required to use ADKAgentModelHandler. " |
| 136 | + "Install it with: pip install google-adk") |
| 137 | + |
| 138 | + if agent is None: |
| 139 | + raise ValueError("'agent' must be an Agent instance or a callable.") |
| 140 | + |
| 141 | + self._agent_or_factory = agent |
| 142 | + self._app_name = app_name |
| 143 | + self._session_service_factory = session_service_factory |
| 144 | + |
| 145 | + super().__init__( |
| 146 | + min_batch_size=min_batch_size, |
| 147 | + max_batch_size=max_batch_size, |
| 148 | + max_batch_duration_secs=max_batch_duration_secs, |
| 149 | + max_batch_weight=max_batch_weight, |
| 150 | + element_size_fn=element_size_fn, |
| 151 | + **kwargs) |
| 152 | + |
| 153 | + def load_model(self) -> "Runner": |
| 154 | + """Instantiates the ADK Runner on the worker. |
| 155 | +
|
| 156 | + Resolves the agent (calling the factory if a callable was provided), then |
| 157 | + creates a :class:`~google.adk.runners.Runner` backed by the configured |
| 158 | + session service. |
| 159 | +
|
| 160 | + Returns: |
| 161 | + A fully initialised :class:`~google.adk.runners.Runner`. |
| 162 | + """ |
| 163 | + if callable(self._agent_or_factory) and not isinstance( |
| 164 | + self._agent_or_factory, Agent): |
| 165 | + agent = self._agent_or_factory() |
| 166 | + else: |
| 167 | + agent = self._agent_or_factory |
| 168 | + |
| 169 | + if self._session_service_factory is not None: |
| 170 | + session_service = self._session_service_factory() |
| 171 | + else: |
| 172 | + session_service = InMemorySessionService() |
| 173 | + |
| 174 | + runner = Runner( |
| 175 | + agent=agent, |
| 176 | + app_name=self._app_name, |
| 177 | + session_service=session_service, |
| 178 | + ) |
| 179 | + LOGGER.info( |
| 180 | + "Loaded ADK Runner for agent '%s' (app_name='%s')", |
| 181 | + agent.name, |
| 182 | + self._app_name, |
| 183 | + ) |
| 184 | + return runner |
| 185 | + |
| 186 | + def run_inference( |
| 187 | + self, |
| 188 | + batch: Sequence[str | genai_Content], |
| 189 | + model: "Runner", |
| 190 | + inference_args: Optional[dict[str, Any]] = None, |
| 191 | + ) -> Iterable[PredictionResult]: |
| 192 | + """Runs the ADK agent on each element in the batch. |
| 193 | +
|
| 194 | + Each element is sent to the agent as a new user turn. The final response |
| 195 | + text from the agent is returned as the ``inference`` field of a |
| 196 | + :class:`~apache_beam.ml.inference.base.PredictionResult`. |
| 197 | +
|
| 198 | + Args: |
| 199 | + batch: A sequence of inputs, each of which is either a ``str`` (the user |
| 200 | + message text) or a :class:`google.genai.types.Content` object (for |
| 201 | + richer multi-part messages). |
| 202 | + model: The :class:`~google.adk.runners.Runner` returned by |
| 203 | + :meth:`load_model`. |
| 204 | + inference_args: Optional dict of extra arguments. Supported keys: |
| 205 | +
|
| 206 | + - ``"session_id"`` (:class:`str`): If supplied, all elements in this |
| 207 | + batch share this session ID, enabling stateful multi-turn |
| 208 | + conversations. If omitted, each element receives a unique auto- |
| 209 | + generated session ID. |
| 210 | + - ``"user_id"`` (:class:`str`): The user identifier to pass to the |
| 211 | + runner. Defaults to ``"beam_user"``. |
| 212 | +
|
| 213 | + Returns: |
| 214 | + An iterable of :class:`~apache_beam.ml.inference.base.PredictionResult`, |
| 215 | + one per input element. |
| 216 | + """ |
| 217 | + if inference_args is None: |
| 218 | + inference_args = {} |
| 219 | + |
| 220 | + user_id: str = inference_args.get("user_id", "beam_user") |
| 221 | + agent_invocations = [] |
| 222 | + elements_with_sessions = [] |
| 223 | + |
| 224 | + for element in batch: |
| 225 | + session_id: str = inference_args.get("session_id", str(uuid.uuid4())) |
| 226 | + |
| 227 | + # Ensure a session exists for this invocation |
| 228 | + try: |
| 229 | + model.session_service.create_session( |
| 230 | + app_name=self._app_name, |
| 231 | + user_id=user_id, |
| 232 | + session_id=session_id, |
| 233 | + ) |
| 234 | + except sessions.SessionExistsError: |
| 235 | + # It's okay if the session already exists for shared session IDs. |
| 236 | + pass |
| 237 | + |
| 238 | + # Wrap plain strings in a Content object |
| 239 | + if isinstance(element, str): |
| 240 | + message = genai_Content(role="user", parts=[genai_Part(text=element)]) |
| 241 | + else: |
| 242 | + # Assume the caller has already constructed a types.Content object |
| 243 | + message = element |
| 244 | + |
| 245 | + agent_invocations.append( |
| 246 | + self._invoke_agent(model, user_id, session_id, message)) |
| 247 | + elements_with_sessions.append(element) |
| 248 | + |
| 249 | + # Run all agent invocations concurrently |
| 250 | + async def _run_concurrently(): |
| 251 | + return await asyncio.gather(*agent_invocations) |
| 252 | + |
| 253 | + response_texts = asyncio.run(_run_concurrently()) |
| 254 | + |
| 255 | + results = [] |
| 256 | + for i, element in enumerate(elements_with_sessions): |
| 257 | + results.append( |
| 258 | + PredictionResult( |
| 259 | + example=element, |
| 260 | + inference=response_texts[i], |
| 261 | + model_id=model.agent.name, |
| 262 | + )) |
| 263 | + |
| 264 | + return results |
| 265 | + |
| 266 | + @staticmethod |
| 267 | + async def _invoke_agent( |
| 268 | + runner: "Runner", |
| 269 | + user_id: str, |
| 270 | + session_id: str, |
| 271 | + message: genai_Content, |
| 272 | + ) -> Optional[str]: |
| 273 | + """Drives the ADK event loop and returns the final response text. |
| 274 | +
|
| 275 | + Args: |
| 276 | + runner: The ADK Runner to invoke. |
| 277 | + user_id: The user ID for this invocation. |
| 278 | + session_id: The session ID for this invocation. |
| 279 | + message: The :class:`google.genai.types.Content` to send. |
| 280 | +
|
| 281 | + Returns: |
| 282 | + The text of the agent's final response, or ``None`` if the agent |
| 283 | + produced no final text response. |
| 284 | + """ |
| 285 | + async for event in runner.run_async( |
| 286 | + user_id=user_id, |
| 287 | + session_id=session_id, |
| 288 | + new_message=message, |
| 289 | + ): |
| 290 | + if event.is_final_response(): |
| 291 | + if event.content: |
| 292 | + return event.content.text |
| 293 | + return None |
| 294 | + |
| 295 | + def get_metrics_namespace(self) -> str: |
| 296 | + return "ADKAgentModelHandler" |
0 commit comments