Skip to content

Commit f39df41

Browse files
Jacksunweicopybara-github
authored andcommitted
feat(conformance): Supports content and state_delta in TestCase.user_messages and initial_state for session creation
PiperOrigin-RevId: 808827170
1 parent 1a91bb2 commit f39df41

File tree

3 files changed

+51
-7
lines changed

3 files changed

+51
-7
lines changed

src/google/adk/cli/conformance/cli_create.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,22 +46,33 @@ async def _create_conformance_test_files(
4646
async with AdkWebServerClient() as client:
4747
# Create a new session for the test
4848
session = await client.create_session(
49-
app_name=test_case.test_spec.agent, user_id=user_id, state={}
49+
app_name=test_case.test_spec.agent,
50+
user_id=user_id,
51+
state=test_case.test_spec.initial_state,
5052
)
5153

5254
# Run the agent with the user messages
5355
for user_message_index, user_message in enumerate(
5456
test_case.test_spec.user_messages
5557
):
56-
content = types.Content(
57-
parts=[types.Part(text=user_message)], role="user"
58-
)
58+
# Create content from UserMessage object
59+
if user_message.content is not None:
60+
content = user_message.content
61+
elif user_message.text is not None:
62+
content = types.UserContent(parts=[types.Part(text=user_message.text)])
63+
else:
64+
raise ValueError(
65+
f"UserMessage at index {user_message_index} has neither text nor"
66+
" content"
67+
)
68+
5969
async for _ in client.run_agent(
6070
RunAgentRequest(
6171
app_name=test_case.test_spec.agent,
6272
user_id=user_id,
6373
session_id=session.id,
6474
new_message=content,
75+
state_delta=user_message.state_delta,
6576
),
6677
mode="record",
6778
test_case_dir=str(test_case_dir),

src/google/adk/cli/conformance/cli_test.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,24 @@ async def _run_user_messages(
120120
for user_message_index, user_message in enumerate(
121121
test_case.test_spec.user_messages
122122
):
123-
content = types.UserContent(parts=[types.Part(text=user_message)])
123+
# Create content from UserMessage object
124+
if user_message.content is not None:
125+
content = user_message.content
126+
elif user_message.text is not None:
127+
content = types.UserContent(parts=[types.Part(text=user_message.text)])
128+
else:
129+
raise ValueError(
130+
f"UserMessage at index {user_message_index} has neither text nor"
131+
" content"
132+
)
124133

125134
request = RunAgentRequest(
126135
app_name=test_case.test_spec.agent,
127136
user_id=self.user_id,
128137
session_id=session_id,
129138
new_message=content,
130139
streaming=False,
140+
state_delta=user_message.state_delta,
131141
)
132142

133143
# Run the agent but don't collect events here
@@ -193,7 +203,9 @@ async def _run_test_case_replay(self, test_case: TestCase) -> _TestResult:
193203
try:
194204
# Create session
195205
session = await self.client.create_session(
196-
app_name=test_case.test_spec.agent, user_id=self.user_id, state={}
206+
app_name=test_case.test_spec.agent,
207+
user_id=self.user_id,
208+
state=test_case.test_spec.initial_state,
197209
)
198210

199211
# Run each user message

src/google/adk/cli/conformance/test_case.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,27 @@
1616

1717
from dataclasses import dataclass
1818
from pathlib import Path
19+
from typing import Any
20+
from typing import Optional
1921

22+
from google.genai import types
2023
from pydantic import BaseModel
2124
from pydantic import ConfigDict
25+
from pydantic import Field
26+
27+
28+
class UserMessage(BaseModel):
29+
30+
# oneof fields - start
31+
text: Optional[str] = None
32+
"""The user message in text."""
33+
34+
content: Optional[types.UserContent] = None
35+
"""The user message in types.Content."""
36+
# oneof fields - end
37+
38+
state_delta: Optional[dict[str, Any]] = None
39+
"""The state changes when running this user message."""
2240

2341

2442
class TestSpec(BaseModel):
@@ -38,7 +56,10 @@ class TestSpec(BaseModel):
3856
agent: str
3957
"""Name of the ADK agent to test against."""
4058

41-
user_messages: list[str]
59+
initial_state: dict[str, Any] = Field(default_factory=dict)
60+
"""The initial state key-value pairs in the creation_session request."""
61+
62+
user_messages: list[UserMessage] = Field(default_factory=list)
4263
"""Sequence of user messages to send to the agent during test execution."""
4364

4465

0 commit comments

Comments
 (0)