Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 61 additions & 22 deletions haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class ComponentPriority(IntEnum):
BLOCKED = 5


class PipelineBase: # noqa: PLW1641
class PipelineBase:
__hash__ = None
"""
Components orchestration engine.

Expand Down Expand Up @@ -899,7 +900,48 @@ def _create_component_span(
parent_span=parent_span,
)

def validate_input(self, data: Dict[str, Any]) -> None:
def _validate_component_input(self, component_name: str, component_inputs: Dict[str, Any]) -> None:
"""
Validates input data for a specific component.

:param component_name: Name of the component.
:param component_inputs: Inputs provided for the component.
:param data: All pipeline input data.
:raises ValueError: If inputs are invalid.
"""
if component_name not in self.graph.nodes:
available_nodes_message = f"Available components: {list(self.graph.nodes.keys())}"
raise ValueError(
f"Component '{component_name}' not found in the pipeline. "
f"{available_nodes_message}"
)
instance = self.graph.nodes[component_name]["instance"]

# Validate that all mandatory inputs are provided either directly or by senders
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
if socket.is_mandatory and not socket.senders and socket_name not in component_inputs:
raise ValueError(f"Missing mandatory input '{socket_name}' for component '{component_name}'.")

# Validate that provided inputs exist in the component's input sockets
for input_name in component_inputs.keys():
if input_name not in instance.__haystack_input__._sockets_dict:
available_inputs_message = f"Available inputs: {list(instance.__haystack_input__._sockets_dict.keys())}"
raise ValueError(
f"Unexpected input '{input_name}' for component '{component_name}'. "
f"{available_inputs_message}"
)

# Validate that inputs are not multiply defined (already sent by another component and also provided directly)
# unless the socket is variadic
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
if socket.senders and socket_name in component_inputs and not socket.is_variadic:
raise ValueError(
f"Input '{socket_name}' for component '{component_name}' is already provided by component "
f"'{socket.senders[0]}'. Do not provide it directly."
)


def _validate_input(self, data: Dict[str, Any]) -> None:
"""
Validates pipeline input data.

Expand All @@ -916,26 +958,23 @@ def validate_input(self, data: Dict[str, Any]) -> None:
If inputs are invalid according to the above.
"""
for component_name, component_inputs in data.items():
if component_name not in self.graph.nodes:
raise ValueError(f"Component named {component_name} not found in the pipeline.")
instance = self.graph.nodes[component_name]["instance"]
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
for input_name in component_inputs.keys():
if input_name not in instance.__haystack_input__._sockets_dict:
raise ValueError(f"Input {input_name} not found in component {component_name}.")

for component_name in self.graph.nodes:
instance = self.graph.nodes[component_name]["instance"]
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
component_inputs = data.get(component_name, {})
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
if socket.senders and socket_name in component_inputs and not socket.is_variadic:
raise ValueError(
f"Input {socket_name} for component {component_name} is already sent by {socket.senders}."
)
self._validate_component_input(component_name, component_inputs, data)

# Additionally, check for components that might be missing inputs,
# even if they were not explicitly mentioned in the `data` dictionary.
# This covers cases where a component has mandatory inputs but receives no data.
for component_name_in_graph in self.graph.nodes:
if component_name_in_graph not in data:
# This component was not in the input data dictionary, check if it has mandatory inputs without senders
instance = self.graph.nodes[component_name_in_graph]["instance"]
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
if socket.is_mandatory and not socket.senders:
error_message = (
f"Missing mandatory input '{socket_name}' for component '{component_name_in_graph}' "
"(not found in input data)."
)
raise ValueError(error_message)


def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
"""
Expand Down
5 changes: 5 additions & 0 deletions releasenotes/notes/refactor-validate-input.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
- Refactore the PipelineBase._validate_input() method to improve clarity and maintainability.
- break down the method into smaller helper functions and enhance error messages for better specificity.
Loading
Loading