Skip to content

Commit 310588f

Browse files
committed
feat: add prompt saving functionality and improve error handling in DebugSaver
1 parent 79f1c80 commit 310588f

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

examples/deepsearch/deepsearch/debug.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ def __init__(
5353
self._replay_threads: Dict[str, threading.Thread] = {}
5454
self._lock = threading.RLock()
5555

56+
def save_prompt(self, session_id: str, prompt: str):
57+
"""Save prompt for replay."""
58+
if self.mode == "replay":
59+
return
60+
61+
with open(self.output_dir / f"{session_id}.prompt", "w") as fp:
62+
fp.write(prompt)
63+
5664
def capture(self, func_name: str = None):
5765
# In replay mode, return a no-op decorator
5866
if self.mode == "replay":
@@ -88,9 +96,13 @@ def wrapper(*args, **kwargs):
8896
"timestamp": datetime.now().isoformat(),
8997
}
9098

91-
# Write the pickled call data
92-
pickle.dump(call_data, self.fp)
93-
self.fp.flush()
99+
try:
100+
# Write the pickled call data
101+
pickle.dump(call_data, self.fp)
102+
self.fp.flush()
103+
logging.debug(f"Captured call to {func_name}: {call_data}")
104+
except Exception as e:
105+
logging.error(f"Failed to capture call data: {e}")
94106

95107
return func(*args, **kwargs)
96108

@@ -246,6 +258,14 @@ def load_replay_session(self, session_id: str):
246258

247259
def load_replays(self):
248260
"""Load all replay data into memory."""
261+
262+
if not any(self.output_dir.glob("*.prompt")):
263+
logging.warning(f"No replay files found in {self.output_dir}")
264+
return
265+
249266
for session_file in self.output_dir.glob("*.prompt"):
250267
session_id = session_file.stem
251-
self.load_replay_session(session_id)
268+
try:
269+
self.load_replay_session(session_id)
270+
except Exception as e:
271+
logging.error(f"Failed to load replay session {session_id}: {e}")

examples/deepsearch/deepsearch/deepsearch.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def worker():
115115
return worker_thread
116116

117117

118-
@debug_saver.capture("notify") if debug_saver else lambda x: x
119118
def notify(metadata, message: Response):
120119
"""Callback to receive notifications from the graph."""
121120
global task_queue, session_manager
@@ -272,14 +271,16 @@ def handle_message(data):
272271

273272
# If in replay mode, trigger replay with current session
274273
global debug_saver
275-
if debug_saver and debug_saver.mode == "replay":
276-
debug_saver.start_replay_session(message.strip(), current_metadata)
277-
return
274+
if debug_saver:
275+
if debug_saver.mode == "replay":
276+
debug_saver.start_replay_session(message.strip(), current_metadata)
277+
return
278+
else:
279+
debug_saver.save_prompt(session_id, message.strip())
278280

279281
# Update session timestamp on activity
280282
session_manager.update_session_timestamp(session_id)
281283

282-
@debug_saver.capture("notify_planai") if debug_saver else lambda x: x
283284
def wrapped_notify_planai(*args, **kwargs):
284285
return notify_planai(*args, **kwargs)
285286

@@ -295,7 +296,6 @@ def wrapped_notify_planai(*args, **kwargs):
295296
session_metadata["provenance"] = provenance
296297

297298

298-
@debug_saver.capture("notify_planai") if debug_saver else lambda x: x
299299
def notify_planai(
300300
metadata: Dict[str, Any],
301301
prefix: ProvenanceChain,
@@ -312,7 +312,7 @@ def notify_planai(
312312
global session_manager
313313
# get the metadata for this session
314314
session_metadata = session_manager.metadata(session_id)
315-
if session_metadata.get("started"):
315+
if session_metadata.get("started") and False: # ignore for now
316316
# this indicates that we failed the task
317317
task_queue.put(
318318
(
@@ -360,6 +360,14 @@ def notify_planai(
360360
)
361361

362362

363+
def patch_notify_functions():
364+
"""Patch the notify functions with debug_saver decorators."""
365+
global notify, notify_planai, debug_saver
366+
if debug_saver:
367+
notify = debug_saver.capture("notify")(notify)
368+
notify_planai = debug_saver.capture("notify_planai")(notify_planai)
369+
370+
363371
def main():
364372
import argparse
365373

@@ -386,6 +394,9 @@ def main():
386394
replay_delay=args.replay_delay,
387395
)
388396

397+
# Patch the notify functions after debug_saver is initialized
398+
patch_notify_functions()
399+
389400
if args.replay:
390401
# Register the original functions for replay
391402
debug_saver.register_replay_handler("notify", notify)

0 commit comments

Comments
 (0)