Skip to content

Commit fe60c76

Browse files
authored
refactor: Refactor _save_pipeline_snapshot and _create_pipeline_snapshot to handle more exceptions (#9871)
* Refactor saving pipeline snapshot to handle the try-except inside and to cover more cases (e.g. try-excepts around our serialization logic) * Add reno * Fix * Adding tests * More tests * small change * fix test * update docstrings
1 parent 0cd297a commit fe60c76

File tree

5 files changed

+223
-105
lines changed

5 files changed

+223
-105
lines changed

haystack/core/pipeline/breakpoint.py

Lines changed: 81 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -144,117 +144,131 @@ def load_pipeline_snapshot(file_path: Union[str, Path]) -> PipelineSnapshot:
144144
return pipeline_snapshot
145145

146146

147-
def _save_pipeline_snapshot_to_file(
148-
*, pipeline_snapshot: PipelineSnapshot, snapshot_file_path: Union[str, Path], dt: datetime
149-
) -> None:
147+
def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot, raise_on_failure: bool = True) -> None:
150148
"""
151149
Save the pipeline snapshot dictionary to a JSON file.
152150
151+
- The filename is generated based on the component name, visit count, and timestamp.
152+
- The component name is taken from the break point's `component_name`.
153+
- The visit count is taken from the pipeline state's `component_visits` for the component name.
154+
- The timestamp is taken from the pipeline snapshot's `timestamp` or the current time if not available.
155+
- The file path is taken from the break point's `snapshot_file_path`.
156+
- If the `snapshot_file_path` is None, the function will return without saving.
157+
153158
:param pipeline_snapshot: The pipeline snapshot to save.
154-
:param snapshot_file_path: The path where to save the file.
155-
:param dt: The datetime object for timestamping.
159+
:param raise_on_failure: If True, raises an exception if saving fails. If False, logs the error and returns.
160+
156161
:raises:
157-
ValueError: If the snapshot_file_path is not a string or a Path object.
158162
Exception: If saving the JSON snapshot fails.
159163
"""
160-
snapshot_file_path = Path(snapshot_file_path) if isinstance(snapshot_file_path, str) else snapshot_file_path
161-
if not isinstance(snapshot_file_path, Path):
162-
raise ValueError("Debug path must be a string or a Path object.")
164+
break_point = pipeline_snapshot.break_point
165+
snapshot_file_path = (
166+
break_point.break_point.snapshot_file_path
167+
if isinstance(break_point, AgentBreakpoint)
168+
else break_point.snapshot_file_path
169+
)
163170

164-
snapshot_file_path.mkdir(exist_ok=True)
171+
if snapshot_file_path is None:
172+
return
173+
174+
dt = pipeline_snapshot.timestamp or datetime.now()
175+
snapshot_dir = Path(snapshot_file_path)
165176

166177
# Generate filename
167178
# We check if the agent_name is provided to differentiate between agent and non-agent breakpoints
168-
if isinstance(pipeline_snapshot.break_point, AgentBreakpoint):
169-
agent_name = pipeline_snapshot.break_point.agent_name
170-
component_name = pipeline_snapshot.break_point.break_point.component_name
171-
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
172-
file_name = f"{agent_name}_{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"
179+
if isinstance(break_point, AgentBreakpoint):
180+
agent_name = break_point.agent_name
181+
component_name = break_point.break_point.component_name
173182
else:
174-
component_name = pipeline_snapshot.break_point.component_name
175-
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
176-
file_name = f"{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"
183+
component_name = break_point.component_name
184+
agent_name = None
185+
186+
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
187+
timestamp = dt.strftime("%Y_%m_%d_%H_%M_%S")
188+
file_name = f"{agent_name + '_' if agent_name else ''}{component_name}_{visit_nr}_{timestamp}.json"
189+
full_path = snapshot_dir / file_name
177190

178191
try:
179-
with open(snapshot_file_path / file_name, "w") as f_out:
192+
snapshot_dir.mkdir(parents=True, exist_ok=True)
193+
with open(full_path, "w") as f_out:
180194
json.dump(pipeline_snapshot.to_dict(), f_out, indent=2)
181-
logger.info(f"Pipeline snapshot saved at: {file_name}")
182-
except Exception as e:
183-
logger.error(f"Failed to save pipeline snapshot: {str(e)}")
184-
raise
195+
logger.info(
196+
"Pipeline snapshot saved to '{full_path}'. You can use this file to debug or resume the pipeline.",
197+
full_path=full_path,
198+
)
199+
except Exception as error:
200+
logger.error("Failed to save pipeline snapshot to '{full_path}'. Error: {e}", full_path=full_path, e=error)
201+
if raise_on_failure:
202+
raise
185203

186204

187205
def _create_pipeline_snapshot(
188206
*,
189207
inputs: dict[str, Any],
208+
component_inputs: dict[str, Any],
190209
break_point: Union[AgentBreakpoint, Breakpoint],
191210
component_visits: dict[str, int],
192-
original_input_data: Optional[dict[str, Any]] = None,
193-
ordered_component_names: Optional[list[str]] = None,
194-
include_outputs_from: Optional[set[str]] = None,
195-
pipeline_outputs: Optional[dict[str, Any]] = None,
211+
original_input_data: dict[str, Any],
212+
ordered_component_names: list[str],
213+
include_outputs_from: set[str],
214+
pipeline_outputs: dict[str, Any],
196215
) -> PipelineSnapshot:
197216
"""
198217
Create a snapshot of the pipeline at the point where the breakpoint was triggered.
199218
200219
:param inputs: The current pipeline snapshot inputs.
220+
:param component_inputs: The inputs to the component that triggered the breakpoint.
201221
:param break_point: The breakpoint that triggered the snapshot, can be AgentBreakpoint or Breakpoint.
202222
:param component_visits: The visit count of the component that triggered the breakpoint.
203223
:param original_input_data: The original input data.
204224
:param ordered_component_names: The ordered component names.
205225
:param include_outputs_from: Set of component names whose outputs should be included in the pipeline results.
226+
:param pipeline_outputs: The current outputs of the pipeline.
227+
:returns:
228+
A PipelineSnapshot containing the state of the pipeline at the point of the breakpoint.
206229
"""
207-
dt = datetime.now()
230+
if isinstance(break_point, AgentBreakpoint):
231+
component_name = break_point.agent_name
232+
else:
233+
component_name = break_point.component_name
208234

209235
transformed_original_input_data = _transform_json_structure(original_input_data)
210-
transformed_inputs = _transform_json_structure(inputs)
236+
transformed_inputs = _transform_json_structure({**inputs, component_name: component_inputs})
237+
238+
try:
239+
serialized_inputs = _serialize_value_with_schema(transformed_inputs)
240+
except Exception as error:
241+
logger.warning(
242+
"Failed to serialize the inputs of the current pipeline state. "
243+
"The inputs in the snapshot will be replaced with an empty dictionary. Error: {e}",
244+
e=error,
245+
)
246+
serialized_inputs = {}
247+
248+
try:
249+
serialized_original_input_data = _serialize_value_with_schema(transformed_original_input_data)
250+
except Exception as error:
251+
logger.warning(
252+
"Failed to serialize original input data for `pipeline.run`. "
253+
"This likely occurred due to non-serializable object types. "
254+
"The snapshot will store an empty dictionary instead. Error: {e}",
255+
e=error,
256+
)
257+
serialized_original_input_data = {}
211258

212259
pipeline_snapshot = PipelineSnapshot(
213260
pipeline_state=PipelineState(
214-
inputs=_serialize_value_with_schema(transformed_inputs), # current pipeline inputs
215-
component_visits=component_visits,
216-
pipeline_outputs=pipeline_outputs or {},
261+
inputs=serialized_inputs, component_visits=component_visits, pipeline_outputs=pipeline_outputs
217262
),
218-
timestamp=dt,
263+
timestamp=datetime.now(),
219264
break_point=break_point,
220-
original_input_data=_serialize_value_with_schema(transformed_original_input_data),
221-
ordered_component_names=ordered_component_names or [],
222-
include_outputs_from=include_outputs_from or set(),
265+
original_input_data=serialized_original_input_data,
266+
ordered_component_names=ordered_component_names,
267+
include_outputs_from=include_outputs_from,
223268
)
224269
return pipeline_snapshot
225270

226271

227-
def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot) -> PipelineSnapshot:
228-
"""
229-
Save the pipeline snapshot to a file.
230-
231-
:param pipeline_snapshot: The pipeline snapshot to save.
232-
233-
:returns:
234-
The dictionary containing the snapshot of the pipeline containing the following keys:
235-
- input_data: The original input data passed to the pipeline.
236-
- timestamp: The timestamp of the breakpoint.
237-
- pipeline_breakpoint: The component name and visit count that triggered the breakpoint.
238-
- pipeline_state: The state of the pipeline when the breakpoint was triggered containing the following keys:
239-
- inputs: The current state of inputs for pipeline components.
240-
- component_visits: The visit count of the components when the breakpoint was triggered.
241-
- ordered_component_names: The order of components in the pipeline.
242-
"""
243-
break_point = pipeline_snapshot.break_point
244-
if isinstance(break_point, AgentBreakpoint):
245-
snapshot_file_path = break_point.break_point.snapshot_file_path
246-
else:
247-
snapshot_file_path = break_point.snapshot_file_path
248-
249-
if snapshot_file_path is not None:
250-
dt = pipeline_snapshot.timestamp or datetime.now()
251-
_save_pipeline_snapshot_to_file(
252-
pipeline_snapshot=pipeline_snapshot, snapshot_file_path=snapshot_file_path, dt=dt
253-
)
254-
255-
return pipeline_snapshot
256-
257-
258272
def _transform_json_structure(data: Union[dict[str, Any], list[Any], Any]) -> Any:
259273
"""
260274
Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level.

haystack/core/pipeline/pipeline.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -360,10 +360,9 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches
360360
and component_name == break_point.agent_name
361361
)
362362
if break_point and (component_break_point_triggered or agent_break_point_triggered):
363-
pipeline_snapshot_inputs_serialised = deepcopy(inputs)
364-
pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs)
365363
new_pipeline_snapshot = _create_pipeline_snapshot(
366-
inputs=pipeline_snapshot_inputs_serialised,
364+
inputs=deepcopy(inputs),
365+
component_inputs=deepcopy(component_inputs),
367366
break_point=break_point,
368367
component_visits=component_visits,
369368
original_input_data=data,
@@ -378,7 +377,7 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches
378377
component_inputs["break_point"] = break_point
379378
component_inputs["parent_snapshot"] = new_pipeline_snapshot
380379

381-
# trigger the breakpoint if needed
380+
# trigger the break point if needed
382381
if component_break_point_triggered:
383382
_trigger_break_point(pipeline_snapshot=new_pipeline_snapshot)
384383

@@ -400,11 +399,10 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches
400399
snapshot_file_path=out_dir,
401400
)
402401

403-
# Create a snapshot of the last good state of the pipeline before the error occurred.
404-
pipeline_snapshot_inputs_serialised = deepcopy(inputs)
405-
pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs)
406-
last_good_state_snapshot = _create_pipeline_snapshot(
407-
inputs=pipeline_snapshot_inputs_serialised,
402+
# Create a snapshot of the state of the pipeline before the error occurred.
403+
pipeline_snapshot = _create_pipeline_snapshot(
404+
inputs=deepcopy(inputs),
405+
component_inputs=deepcopy(component_inputs),
408406
break_point=break_point,
409407
component_visits=component_visits,
410408
original_input_data=data,
@@ -417,23 +415,12 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches
417415
# We take the agent snapshot and attach it to the pipeline snapshot we create here.
418416
# We also update the break_point to be an AgentBreakpoint.
419417
if error.pipeline_snapshot and error.pipeline_snapshot.agent_snapshot:
420-
last_good_state_snapshot.agent_snapshot = error.pipeline_snapshot.agent_snapshot
421-
last_good_state_snapshot.break_point = error.pipeline_snapshot.agent_snapshot.break_point
422-
423-
# Attach the last good state snapshot to the error before re-raising it and saving to disk
424-
error.pipeline_snapshot = last_good_state_snapshot
425-
426-
try:
427-
_save_pipeline_snapshot(pipeline_snapshot=last_good_state_snapshot)
428-
logger.info(
429-
"Saved a snapshot of the pipeline's last valid state to '{out_path}'. "
430-
"Review this snapshot to debug the error and resume the pipeline from here.",
431-
out_path=out_dir,
432-
)
433-
except Exception as save_error:
434-
logger.error(
435-
"Failed to save a snapshot of the pipeline's last valid state with error: {e}", e=save_error
436-
)
418+
pipeline_snapshot.agent_snapshot = error.pipeline_snapshot.agent_snapshot
419+
pipeline_snapshot.break_point = error.pipeline_snapshot.agent_snapshot.break_point
420+
421+
# Attach the pipeline snapshot to the error before re-raising
422+
error.pipeline_snapshot = pipeline_snapshot
423+
_save_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot, raise_on_failure=False)
437424
raise error
438425

439426
# Updates global input state with component outputs and returns outputs that should go to

haystack/dataclasses/breakpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,12 @@ class PipelineSnapshot:
189189
"""
190190
A dataclass to hold a snapshot of the pipeline at a specific point in time.
191191
192+
:param original_input_data: The original input data provided to the pipeline.
193+
:param ordered_component_names: A list of component names in the order they were visited.
192194
:param pipeline_state: The state of the pipeline at the time of the snapshot.
193195
:param break_point: The breakpoint that triggered the snapshot.
194196
:param agent_snapshot: Optional agent snapshot if the breakpoint is an agent breakpoint.
195197
:param timestamp: A timestamp indicating when the snapshot was taken.
196-
:param original_input_data: The original input data provided to the pipeline.
197-
:param ordered_component_names: A list of component names in the order they were visited.
198198
:param include_outputs_from: Set of component names whose outputs should be included in the pipeline results.
199199
"""
200200

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
enhancements:
3+
- |
4+
Refactored `_save_pipeline_snapshot` to consolidate try-except logic and added a `raise_on_failure` option to control whether save failures raise an exception or are logged.
5+
`_create_pipeline_snapshot` now wraps `_serialize_value_with_schema` in try-except blocks to prevent failures from non-serializable pipeline inputs.

0 commit comments

Comments
 (0)