Skip to content

Commit 9c2b709

Browse files
Jacksunweicopybara-github
authored andcommitted
refactor(comformance): Improves field comparison logic in replay plugin with nested exclude dict from pydantic v2
Also use `ReplayConfigError` to replace `ValueError`s PiperOrigin-RevId: 808750606
1 parent 21c26f9 commit 9c2b709

File tree

1 file changed

+41
-63
lines changed

1 file changed

+41
-63
lines changed

src/google/adk/cli/plugins/replay_plugin.py

Lines changed: 41 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ class ReplayVerificationError(Exception):
5151
pass
5252

5353

54+
class ReplayConfigError(Exception):
55+
"""Exception raised when replay configuration is invalid or missing."""
56+
57+
pass
58+
59+
5460
class _InvocationReplayState(BaseModel):
5561
"""Per-invocation replay state to isolate concurrent runs."""
5662

@@ -93,7 +99,7 @@ async def before_model_callback(
9399
return None
94100

95101
if (state := self._get_invocation_state(callback_context)) is None:
96-
raise ValueError(
102+
raise ReplayConfigError(
97103
"Replay state not initialized. Ensure before_run created it."
98104
)
99105

@@ -122,7 +128,7 @@ async def before_tool_callback(
122128
return None
123129

124130
if (state := self._get_invocation_state(tool_context)) is None:
125-
raise ValueError(
131+
raise ReplayConfigError(
126132
"Replay state not initialized. Ensure before_run created it."
127133
)
128134

@@ -188,20 +194,24 @@ def _load_invocation_state(
188194
msg_index = config.get("user_message_index")
189195

190196
if not case_dir or msg_index is None:
191-
raise ValueError("Replay parameters are missing from session state")
197+
raise ReplayConfigError(
198+
"Replay parameters are missing from session state"
199+
)
192200

193201
# Load recordings
194202
recordings_file = Path(case_dir) / "generated-recordings.yaml"
195203

196204
if not recordings_file.exists():
197-
raise ValueError(f"Recordings file not found: {recordings_file}")
205+
raise ReplayConfigError(f"Recordings file not found: {recordings_file}")
198206

199207
try:
200208
with recordings_file.open("r", encoding="utf-8") as f:
201209
recordings_data = yaml.safe_load(f)
202210
recordings = Recordings.model_validate(recordings_data)
203211
except Exception as e:
204-
raise ValueError(f"Failed to load recordings from {recordings_file}: {e}")
212+
raise ReplayConfigError(
213+
f"Failed to load recordings from {recordings_file}: {e}"
214+
) from e
205215

206216
# Load and store invocation state
207217
state = _InvocationReplayState(
@@ -320,62 +330,28 @@ def _verify_llm_request_match(
320330
agent_index: int,
321331
) -> None:
322332
"""Verify that the current LLM request exactly matches the recorded one."""
323-
self._verify_config_match(
324-
recorded_request, current_request, agent_name, agent_index
325-
)
326-
handled_fields: set[str] = {"config"}
327-
ignored_fields: set[str] = {"live_connect_config"}
328-
exclude_fields = handled_fields | ignored_fields
329-
if not self._compare_fields(
330-
recorded_request, current_request, exclude_fields=exclude_fields
331-
):
332-
raise ValueError(
333-
f"LLM request mismatch for agent '{agent_name}' (index"
334-
f" {agent_index}): "
335-
"recorded:"
336-
f" {recorded_request.model_dump(exclude_none=True, exclude=exclude_fields)},"
337-
" current:"
338-
f" {current_request.model_dump(exclude_none=True, exclude=exclude_fields)}"
339-
)
340-
341-
def _compare_fields(
342-
self,
343-
obj1: BaseModel,
344-
obj2: BaseModel,
345-
*,
346-
exclude_fields: Optional[set[str]] = None,
347-
) -> bool:
348-
"""Compare two Pydantic models excluding specified fields."""
349-
exclude_fields = exclude_fields or set()
350-
dict1 = obj1.model_dump(exclude_none=True, exclude=exclude_fields)
351-
dict2 = obj2.model_dump(exclude_none=True, exclude=exclude_fields)
352-
return dict1 == dict2
353-
354-
def _verify_config_match(
355-
self,
356-
recorded_request: LlmRequest,
357-
current_request: LlmRequest,
358-
agent_name: str,
359-
agent_index: int,
360-
) -> None:
361-
"""Verify that the config matches between recorded and current requests."""
362-
# Fields to ignore when comparing GenerateContentConfig (denylist approach)
363-
ignored_fields: set[str] = {
364-
"http_options",
365-
"labels",
333+
# Comprehensive exclude dict for all fields that can differ between runs
334+
excluded_fields = {
335+
"live_connect_config": True,
336+
"config": { # some config fields can vary per run
337+
"http_options": True,
338+
"labels": True,
339+
},
366340
}
367341

368-
if not self._compare_fields(
369-
recorded_request.config,
370-
current_request.config,
371-
exclude_fields=ignored_fields,
372-
):
373-
raise ValueError(
374-
f"Config mismatch for agent '{agent_name}' (index {agent_index}): "
375-
"recorded:"
376-
f" {recorded_request.config.model_dump(exclude_none=True, exclude=ignored_fields)},"
377-
" current:"
378-
f" {current_request.config.model_dump(exclude_none=True, exclude=ignored_fields)}"
342+
# Compare using model dumps with nested exclude dict
343+
recorded_dict = recorded_request.model_dump(
344+
exclude_none=True, exclude=excluded_fields, exclude_defaults=True
345+
)
346+
current_dict = current_request.model_dump(
347+
exclude_none=True, exclude=excluded_fields, exclude_defaults=True
348+
)
349+
350+
if recorded_dict != current_dict:
351+
raise ReplayVerificationError(
352+
f"""LLM request mismatch for agent '{agent_name}' (index {agent_index}):
353+
recorded: {recorded_dict}
354+
current: {current_dict}"""
379355
)
380356

381357
def _verify_tool_call_match(
@@ -389,12 +365,14 @@ def _verify_tool_call_match(
389365
"""Verify that the current tool call exactly matches the recorded one."""
390366
if recorded_call.name != tool_name:
391367
raise ReplayVerificationError(
392-
f"Tool name mismatch for agent '{agent_name}' at index {agent_index}:"
393-
f" recorded='{recorded_call.name}', current='{tool_name}'"
368+
f"""Tool name mismatch for agent '{agent_name}' at index {agent_index}:
369+
recorded: '{recorded_call.name}'
370+
current: '{tool_name}'"""
394371
)
395372

396373
if recorded_call.args != tool_args:
397374
raise ReplayVerificationError(
398-
f"Tool args mismatch for agent '{agent_name}' at index {agent_index}:"
399-
f" recorded={recorded_call.args}, current={tool_args}"
375+
f"""Tool args mismatch for agent '{agent_name}' at index {agent_index}:
376+
recorded: {recorded_call.args}
377+
current: {tool_args}"""
400378
)

0 commit comments

Comments
 (0)