@@ -51,6 +51,12 @@ class ReplayVerificationError(Exception):
51
51
pass
52
52
53
53
54
+ class ReplayConfigError (Exception ):
55
+ """Exception raised when replay configuration is invalid or missing."""
56
+
57
+ pass
58
+
59
+
54
60
class _InvocationReplayState (BaseModel ):
55
61
"""Per-invocation replay state to isolate concurrent runs."""
56
62
@@ -93,7 +99,7 @@ async def before_model_callback(
93
99
return None
94
100
95
101
if (state := self ._get_invocation_state (callback_context )) is None :
96
- raise ValueError (
102
+ raise ReplayConfigError (
97
103
"Replay state not initialized. Ensure before_run created it."
98
104
)
99
105
@@ -122,7 +128,7 @@ async def before_tool_callback(
122
128
return None
123
129
124
130
if (state := self ._get_invocation_state (tool_context )) is None :
125
- raise ValueError (
131
+ raise ReplayConfigError (
126
132
"Replay state not initialized. Ensure before_run created it."
127
133
)
128
134
@@ -188,20 +194,24 @@ def _load_invocation_state(
188
194
msg_index = config .get ("user_message_index" )
189
195
190
196
if not case_dir or msg_index is None :
191
- raise ValueError ("Replay parameters are missing from session state" )
197
+ raise ReplayConfigError (
198
+ "Replay parameters are missing from session state"
199
+ )
192
200
193
201
# Load recordings
194
202
recordings_file = Path (case_dir ) / "generated-recordings.yaml"
195
203
196
204
if not recordings_file .exists ():
197
- raise ValueError (f"Recordings file not found: { recordings_file } " )
205
+ raise ReplayConfigError (f"Recordings file not found: { recordings_file } " )
198
206
199
207
try :
200
208
with recordings_file .open ("r" , encoding = "utf-8" ) as f :
201
209
recordings_data = yaml .safe_load (f )
202
210
recordings = Recordings .model_validate (recordings_data )
203
211
except Exception as e :
204
- raise ValueError (f"Failed to load recordings from { recordings_file } : { e } " )
212
+ raise ReplayConfigError (
213
+ f"Failed to load recordings from { recordings_file } : { e } "
214
+ ) from e
205
215
206
216
# Load and store invocation state
207
217
state = _InvocationReplayState (
@@ -320,62 +330,28 @@ def _verify_llm_request_match(
320
330
agent_index : int ,
321
331
) -> None :
322
332
"""Verify that the current LLM request exactly matches the recorded one."""
323
- self ._verify_config_match (
324
- recorded_request , current_request , agent_name , agent_index
325
- )
326
- handled_fields : set [str ] = {"config" }
327
- ignored_fields : set [str ] = {"live_connect_config" }
328
- exclude_fields = handled_fields | ignored_fields
329
- if not self ._compare_fields (
330
- recorded_request , current_request , exclude_fields = exclude_fields
331
- ):
332
- raise ValueError (
333
- f"LLM request mismatch for agent '{ agent_name } ' (index"
334
- f" { agent_index } ): "
335
- "recorded:"
336
- f" { recorded_request .model_dump (exclude_none = True , exclude = exclude_fields )} ,"
337
- " current:"
338
- f" { current_request .model_dump (exclude_none = True , exclude = exclude_fields )} "
339
- )
340
-
341
- def _compare_fields (
342
- self ,
343
- obj1 : BaseModel ,
344
- obj2 : BaseModel ,
345
- * ,
346
- exclude_fields : Optional [set [str ]] = None ,
347
- ) -> bool :
348
- """Compare two Pydantic models excluding specified fields."""
349
- exclude_fields = exclude_fields or set ()
350
- dict1 = obj1 .model_dump (exclude_none = True , exclude = exclude_fields )
351
- dict2 = obj2 .model_dump (exclude_none = True , exclude = exclude_fields )
352
- return dict1 == dict2
353
-
354
- def _verify_config_match (
355
- self ,
356
- recorded_request : LlmRequest ,
357
- current_request : LlmRequest ,
358
- agent_name : str ,
359
- agent_index : int ,
360
- ) -> None :
361
- """Verify that the config matches between recorded and current requests."""
362
- # Fields to ignore when comparing GenerateContentConfig (denylist approach)
363
- ignored_fields : set [str ] = {
364
- "http_options" ,
365
- "labels" ,
333
+ # Comprehensive exclude dict for all fields that can differ between runs
334
+ excluded_fields = {
335
+ "live_connect_config" : True ,
336
+ "config" : { # some config fields can vary per run
337
+ "http_options" : True ,
338
+ "labels" : True ,
339
+ },
366
340
}
367
341
368
- if not self ._compare_fields (
369
- recorded_request .config ,
370
- current_request .config ,
371
- exclude_fields = ignored_fields ,
372
- ):
373
- raise ValueError (
374
- f"Config mismatch for agent '{ agent_name } ' (index { agent_index } ): "
375
- "recorded:"
376
- f" { recorded_request .config .model_dump (exclude_none = True , exclude = ignored_fields )} ,"
377
- " current:"
378
- f" { current_request .config .model_dump (exclude_none = True , exclude = ignored_fields )} "
342
+ # Compare using model dumps with nested exclude dict
343
+ recorded_dict = recorded_request .model_dump (
344
+ exclude_none = True , exclude = excluded_fields , exclude_defaults = True
345
+ )
346
+ current_dict = current_request .model_dump (
347
+ exclude_none = True , exclude = excluded_fields , exclude_defaults = True
348
+ )
349
+
350
+ if recorded_dict != current_dict :
351
+ raise ReplayVerificationError (
352
+ f"""LLM request mismatch for agent '{ agent_name } ' (index { agent_index } ):
353
+ recorded: { recorded_dict }
354
+ current: { current_dict } """
379
355
)
380
356
381
357
def _verify_tool_call_match (
@@ -389,12 +365,14 @@ def _verify_tool_call_match(
389
365
"""Verify that the current tool call exactly matches the recorded one."""
390
366
if recorded_call .name != tool_name :
391
367
raise ReplayVerificationError (
392
- f"Tool name mismatch for agent '{ agent_name } ' at index { agent_index } :"
393
- f" recorded='{ recorded_call .name } ', current='{ tool_name } '"
368
+ f"""Tool name mismatch for agent '{ agent_name } ' at index { agent_index } :
369
+ recorded: '{ recorded_call .name } '
370
+ current: '{ tool_name } '"""
394
371
)
395
372
396
373
if recorded_call .args != tool_args :
397
374
raise ReplayVerificationError (
398
- f"Tool args mismatch for agent '{ agent_name } ' at index { agent_index } :"
399
- f" recorded={ recorded_call .args } , current={ tool_args } "
375
+ f"""Tool args mismatch for agent '{ agent_name } ' at index { agent_index } :
376
+ recorded: { recorded_call .args }
377
+ current: { tool_args } """
400
378
)
0 commit comments