Skip to content

Commit ab76f1f

Browse files
committed
Updated examples/adk_agent.py
1 parent b326bb7 commit ab76f1f

File tree

8,187 files changed

+6660028
-19
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

8,187 files changed

+6660028
-19
lines changed

examples/google_adk/adk_agent.py

Lines changed: 186 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,32 @@
66
"""
77

88
from __future__ import annotations
9+
10+
import asyncio
11+
import logging
12+
from collections.abc import Mapping
913
from typing import Any, Dict, TypedDict, cast
14+
1015
from agentlightning import LLM, LitAgent, NamedResources
1116
from agentlightning.types import Rollout, RolloutRawResult
1217

18+
logger = logging.getLogger(__name__)
19+
20+
try: # pragma: no cover - import guarded for optional dependency.
21+
from google.genai import types as genai_types
22+
23+
from google.adk.agents.llm_agent import LlmAgent
24+
from google.adk.apps.app import App
25+
from google.adk.models.lite_llm import LiteLlm
26+
from google.adk.runners import InMemoryRunner
27+
from google.adk.utils.context_utils import Aclosing
28+
29+
_HAS_GOOGLE_ADK = True
30+
except ImportError: # pragma: no cover
31+
genai_types = None # type: ignore[assignment]
32+
LlmAgent = App = LiteLlm = InMemoryRunner = Aclosing = None # type: ignore[assignment]
33+
_HAS_GOOGLE_ADK = False
34+
1335

1436
class AdkTask(TypedDict):
1537
"""
@@ -29,29 +51,172 @@ class AdkTask(TypedDict):
2951

3052

3153
class LitAdkAgent(LitAgent[AdkTask]):
32-
"""Basic ADK + LitAgent example."""
33-
34-
def rollout(self, task: AdkTask, resources: NamedResources, rollout: Rollout) -> RolloutRawResult:
35-
# get the llm resource
36-
llm: LLM = cast(LLM, resources.get("main_llm"))
37-
54+
"""ADK-backed agent that produces observability-friendly rollouts."""
55+
56+
def rollout(self, task: AdkTask, resources: NamedResources, rollout: Rollout) -> RolloutRawResult: # type: ignore[override]
57+
"""Synchronous entry point – forwards to ``rollout_async``."""
58+
59+
if not _HAS_GOOGLE_ADK: # pragma: no cover
60+
raise RuntimeError(
61+
"google-adk>=0.3.0 is required to run this example. "
62+
"Install it with `pip install google-adk` or enable the "
63+
"`adk` optional dependency group."
64+
)
65+
66+
try:
67+
return asyncio.run(self.rollout_async(task, resources, rollout))
68+
except RuntimeError as exc: # pragma: no cover - defensive path
69+
if "asyncio.run()" in str(exc):
70+
raise RuntimeError(
71+
"LitAdkAgent.rollout cannot be executed while an event loop "
72+
"is already running. Call `rollout_async` instead."
73+
) from exc
74+
raise
75+
76+
async def rollout_async( # type: ignore[override]
77+
self,
78+
task: AdkTask,
79+
resources: NamedResources,
80+
rollout: Rollout,
81+
) -> RolloutRawResult:
82+
"""Runs a single rollout by delegating to ADK's orchestration runtime."""
83+
84+
if not _HAS_GOOGLE_ADK: # pragma: no cover
85+
raise RuntimeError(
86+
"google-adk>=0.3.0 is required to run this example. "
87+
"Install it with `pip install google-adk` or enable the "
88+
"`adk` optional dependency group."
89+
)
90+
91+
llm = cast(LLM, resources.get("main_llm"))
3892
question = task["question"]
39-
app_id = task["app_id"]
4093
truth = task["ground_truth"]
94+
app_id = task["app_id"] or "adk_agent_app"
95+
96+
adk_model = self._build_adk_model(llm)
97+
adk_agent = self._build_adk_agent(adk_model, task)
98+
app = App(name=app_id, root_agent=adk_agent)
99+
100+
runner = InMemoryRunner(app=app)
101+
try:
102+
session_state = self._build_session_state(task)
103+
session = await runner.session_service.create_session(
104+
app_name=app.name,
105+
user_id="agent_lightning_user",
106+
state=session_state,
107+
)
108+
109+
message = genai_types.Content( # type: ignore[union-attr]
110+
role="user",
111+
parts=[genai_types.Part.from_text(text=question)], # type: ignore[union-attr]
112+
)
113+
114+
last_response = ""
115+
async with Aclosing( # type: ignore[union-attr]
116+
runner.run_async(
117+
user_id=session.user_id,
118+
session_id=session.id,
119+
new_message=message,
120+
)
121+
) as agen:
122+
async for event in agen:
123+
if event.content and event.content.parts:
124+
text = "".join(part.text or "" for part in event.content.parts).strip()
125+
if text:
126+
last_response = text
127+
128+
reward = self._compute_reward(last_response, truth)
129+
return reward
130+
except Exception as exc: # pragma: no cover - surfaced via training logs
131+
logger.exception("ADK rollout failed: %s", exc)
132+
return 0.0
133+
finally:
134+
await runner.close()
135+
136+
@staticmethod
137+
def _build_adk_model(llm: LLM) -> LiteLlm: # type: ignore[valid-type]
138+
"""Create a LiteLLM-backed ADK model from the Agent-Lightning LLM resource."""
139+
140+
sampling_params = llm.sampling_parameters or {}
141+
temperature: float | None = None
142+
top_p: float | None = None
143+
if isinstance(sampling_params, Mapping):
144+
temperature = sampling_params.get("temperature")
145+
top_p = sampling_params.get("top_p")
146+
147+
llm_kwargs: Dict[str, Any] = {}
148+
if llm.endpoint:
149+
llm_kwargs["api_base"] = llm.endpoint
150+
if llm.api_key:
151+
llm_kwargs["api_key"] = llm.api_key
152+
if temperature is not None:
153+
llm_kwargs["temperature"] = float(temperature)
154+
if top_p is not None:
155+
llm_kwargs["top_p"] = float(top_p)
156+
157+
return LiteLlm(model=llm.model, **llm_kwargs) # type: ignore[arg-type]
158+
159+
def _build_adk_agent(self, adk_model: LiteLlm, task: AdkTask) -> LlmAgent: # type: ignore[valid-type]
160+
"""Construct the ADK agent definition for this rollout."""
161+
162+
instruction = self._compose_instruction(task)
163+
agent_name = task["app_id"] or "adk_agent"
164+
165+
return LlmAgent(
166+
name=agent_name,
167+
model=adk_model,
168+
instruction=instruction,
169+
)
170+
171+
@staticmethod
172+
def _build_session_state(task: AdkTask) -> Dict[str, Any]:
173+
"""Extract session state from the task metadata."""
174+
175+
session_state: Dict[str, Any] = {}
176+
meta = task.get("meta")
177+
if isinstance(meta, Mapping):
178+
session_state.update(meta)
179+
return session_state
180+
181+
def _compose_instruction(self, task: AdkTask) -> str:
182+
"""Compose the instruction prompt for the ADK agent."""
183+
184+
base_instruction = (
185+
"You are an enterprise assistant integrated with Google ADK. "
186+
"Use available tools to reason carefully and produce concise, factual answers. "
187+
"Explain limitations when information is missing."
188+
)
189+
190+
meta = task.get("meta")
191+
if isinstance(meta, Mapping):
192+
# Recognise common fields that may contain additional guidance.
193+
for key in ("instruction", "goal", "context", "description"):
194+
if meta.get(key):
195+
return f"{base_instruction}\n\nAdditional context:\n{meta[key]}"
196+
197+
return base_instruction
198+
199+
@staticmethod
200+
def _compute_reward(answer: str, truth: str) -> float:
201+
"""Simple lexical reward comparing the ADK answer with ground truth."""
202+
203+
if not answer or not truth:
204+
return 0.0
205+
206+
answer_lower = answer.lower()
207+
truth_lower = truth.lower()
208+
return 1.0 if truth_lower in answer_lower else 0.0
41209

42-
# TODO: hook up real ADK orchestration later
43-
# for now just simulate an action string
44-
action = f"adk://{app_id}?plan={question}"
45-
46-
# quick check for correctness
47-
reward = 1.0 if truth and truth.lower() in action.lower() else 0.0
48210

49-
return reward
211+
# Minimal smoke-test entry point (manual run)
212+
if __name__ == "__main__": # pragma: no cover - manual verification helper
213+
import sys
50214

215+
if not _HAS_GOOGLE_ADK:
216+
sys.exit(
217+
"google-adk is not installed. Run `pip install google-adk` before executing this script."
218+
)
51219

52-
# Minimal smoke-test entry point (manual run)
53-
if __name__ == "__main__":
54-
# very minimal test run
55220
sample_task: AdkTask = {
56221
"question": "Create a calendar event for Monday 10am titled 'Standup'",
57222
"app_id": "sample_calendar_app",
@@ -62,15 +227,17 @@ def rollout(self, task: AdkTask, resources: NamedResources, rollout: Rollout) ->
62227
resources: NamedResources = {
63228
"main_llm": LLM(
64229
endpoint="http://localhost:8000/v1",
65-
model="meta-llama/Meta-Llama-3-8B-Instruct"
230+
model="gpt-4.1-mini",
231+
sampling_parameters={"temperature": 0.0},
232+
api_key="dummy-key",
66233
),
67234
}
68235

69236
class DummyRollout:
70237
pass
71238

72239
agent = LitAdkAgent()
73-
result = agent.rollout(sample_task, resources, cast(Rollout, DummyRollout()))
240+
result = asyncio.run(agent.rollout_async(sample_task, resources, cast(Rollout, DummyRollout())))
74241
print("Reward:", result)
75242

76243

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# This is a stub package designed to roughly emulate the _yaml
2+
# extension module, which previously existed as a standalone module
3+
# and has been moved into the `yaml` package namespace.
4+
# It does not perfectly mimic its old counterpart, but should get
5+
# close enough for anyone who's relying on it even when they shouldn't.
6+
import yaml
7+
8+
# in some circumstances, the yaml module we imoprted may be from a different version, so we need
9+
# to tread carefully when poking at it here (it may not have the attributes we expect)
10+
if not getattr(yaml, '__with_libyaml__', False):
11+
from sys import version_info
12+
13+
exc = ModuleNotFoundError if version_info >= (3, 6) else ImportError
14+
raise exc("No module named '_yaml'")
15+
else:
16+
from yaml._yaml import *
17+
import warnings
18+
warnings.warn(
19+
'The _yaml extension module is now located at yaml._yaml'
20+
' and its location is subject to change. To use the'
21+
' LibYAML-based parser and emitter, import from `yaml`:'
22+
' `from yaml import CLoader as Loader, CDumper as Dumper`.',
23+
DeprecationWarning
24+
)
25+
del warnings
26+
# Don't `del yaml` here because yaml is actually an existing
27+
# namespace member of _yaml.
28+
29+
__name__ = '_yaml'
30+
# If the module is top-level (i.e. not a part of any specific package)
31+
# then the attribute should be set to ''.
32+
# https://docs.python.org/3.8/library/types.html
33+
__package__ = ''
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pip

0 commit comments

Comments
 (0)