diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py
index 2db9cdb1ee..a2289334b0 100644
--- a/haystack/core/pipeline/base.py
+++ b/haystack/core/pipeline/base.py
@@ -71,7 +71,8 @@ class ComponentPriority(IntEnum):
     BLOCKED = 5
 
 
-class PipelineBase:  # noqa: PLW1641
+class PipelineBase:
+    __hash__ = None
     """
     Components orchestration engine.
 
@@ -899,7 +900,47 @@ def _create_component_span(
         parent_span=parent_span,
     )
 
-    def validate_input(self, data: Dict[str, Any]) -> None:
+    def _validate_component_input(self, component_name: str, component_inputs: Dict[str, Any]) -> None:
+        """
+        Validates input data for a specific component.
+
+        :param component_name: Name of the component.
+        :param component_inputs: Inputs provided for the component.
+        :raises ValueError: If inputs are invalid.
+        """
+        if component_name not in self.graph.nodes:
+            available_nodes_message = f"Available components: {list(self.graph.nodes.keys())}"
+            raise ValueError(
+                f"Component '{component_name}' not found in the pipeline. "
+                f"{available_nodes_message}"
+            )
+        instance = self.graph.nodes[component_name]["instance"]
+
+        # Validate that all mandatory inputs are provided either directly or by senders
+        for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
+            if socket.is_mandatory and not socket.senders and socket_name not in component_inputs:
+                raise ValueError(f"Missing mandatory input '{socket_name}' for component '{component_name}'.")
+
+        # Validate that provided inputs exist in the component's input sockets
+        for input_name in component_inputs.keys():
+            if input_name not in instance.__haystack_input__._sockets_dict:
+                available_inputs_message = f"Available inputs: {list(instance.__haystack_input__._sockets_dict.keys())}"
+                raise ValueError(
+                    f"Unexpected input '{input_name}' for component '{component_name}'. 
" + f"{available_inputs_message}" + ) + + # Validate that inputs are not multiply defined (already sent by another component and also provided directly) + # unless the socket is variadic + for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): + if socket.senders and socket_name in component_inputs and not socket.is_variadic: + raise ValueError( + f"Input '{socket_name}' for component '{component_name}' is already provided by component " + f"'{socket.senders[0]}'. Do not provide it directly." + ) + + + def _validate_input(self, data: Dict[str, Any]) -> None: """ Validates pipeline input data. @@ -916,26 +958,23 @@ def validate_input(self, data: Dict[str, Any]) -> None: If inputs are invalid according to the above. """ for component_name, component_inputs in data.items(): - if component_name not in self.graph.nodes: - raise ValueError(f"Component named {component_name} not found in the pipeline.") - instance = self.graph.nodes[component_name]["instance"] - for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): - if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs: - raise ValueError(f"Missing input for component {component_name}: {socket_name}") - for input_name in component_inputs.keys(): - if input_name not in instance.__haystack_input__._sockets_dict: - raise ValueError(f"Input {input_name} not found in component {component_name}.") - - for component_name in self.graph.nodes: - instance = self.graph.nodes[component_name]["instance"] - for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): - component_inputs = data.get(component_name, {}) - if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs: - raise ValueError(f"Missing input for component {component_name}: {socket_name}") - if socket.senders and socket_name in component_inputs and not socket.is_variadic: - raise ValueError( - f"Input {socket_name} for component 
{component_name} is already sent by {socket.senders}."
-                    )
+            self._validate_component_input(component_name, component_inputs)
+
+        # Additionally, check for components that might be missing inputs,
+        # even if they were not explicitly mentioned in the `data` dictionary.
+        # This covers cases where a component has mandatory inputs but receives no data.
+        for component_name_in_graph in self.graph.nodes:
+            if component_name_in_graph not in data:
+                # This component was not in the input data dictionary, check if it has mandatory inputs without senders
+                instance = self.graph.nodes[component_name_in_graph]["instance"]
+                for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
+                    if socket.is_mandatory and not socket.senders:
+                        error_message = (
+                            f"Missing mandatory input '{socket_name}' for component '{component_name_in_graph}' "
+                            "(not found in input data)."
+                        )
+                        raise ValueError(error_message)
+
 
     def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
         """
diff --git a/releasenotes/notes/refactor-validate-input.yaml b/releasenotes/notes/refactor-validate-input.yaml
new file mode 100644
index 0000000000..335f5c3bcd
--- /dev/null
+++ b/releasenotes/notes/refactor-validate-input.yaml
@@ -0,0 +1,5 @@
+---
+enhancements:
+  - |
+    - Refactored the PipelineBase._validate_input() method to improve clarity and maintainability.
+    - Broke down the method into smaller helper functions and enhanced error messages for better specificity.
\ No newline at end of file diff --git a/test/core/pipeline/test_pipeline_base.py b/test/core/pipeline/test_pipeline_base.py index 37fc40c890..2d60fb44c1 100644 --- a/test/core/pipeline/test_pipeline_base.py +++ b/test/core/pipeline/test_pipeline_base.py @@ -175,6 +175,110 @@ def test_draw(self, mock_to_mermaid_image, tmp_path): pipe.draw(path=image_path) assert image_path.read_bytes() == mock_to_mermaid_image.return_value + def test_find_super_components(self): + """ + Test that the pipeline can find super components in it's pipeline. + """ + from haystack import Pipeline + from haystack.components.converters import MultiFileConverter + from haystack.components.preprocessors import DocumentPreprocessor + from haystack.components.writers import DocumentWriter + from haystack.document_stores.in_memory import InMemoryDocumentStore + + multi_file_converter = MultiFileConverter() + doc_processor = DocumentPreprocessor() + + pipeline = Pipeline() + pipeline.add_component("converter", multi_file_converter) + pipeline.add_component("preprocessor", doc_processor) + pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore())) + pipeline.connect("converter", "preprocessor") + pipeline.connect("preprocessor", "writer") + + result = pipeline._find_super_components() + + assert len(result) == 2 + assert [("converter", multi_file_converter), ("preprocessor", doc_processor)] == result + + def test_merge_super_component_pipelines(self): + from haystack import Pipeline + from haystack.components.converters import MultiFileConverter + from haystack.components.preprocessors import DocumentPreprocessor + from haystack.components.writers import DocumentWriter + from haystack.document_stores.in_memory import InMemoryDocumentStore + + multi_file_converter = MultiFileConverter() + doc_processor = DocumentPreprocessor() + + pipeline = Pipeline() + pipeline.add_component("converter", multi_file_converter) + pipeline.add_component("preprocessor", doc_processor) + 
pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore())) + pipeline.connect("converter", "preprocessor") + pipeline.connect("preprocessor", "writer") + + merged_graph, super_component_components = pipeline._merge_super_component_pipelines() + + assert super_component_components == { + "router": "converter", + "docx": "converter", + "html": "converter", + "json": "converter", + "md": "converter", + "text": "converter", + "pdf": "converter", + "pptx": "converter", + "xlsx": "converter", + "joiner": "converter", + "csv": "converter", + "splitter": "preprocessor", + "cleaner": "preprocessor", + } + + expected_nodes = [ + "cleaner", + "csv", + "docx", + "html", + "joiner", + "json", + "md", + "pdf", + "pptx", + "router", + "splitter", + "text", + "writer", + "xlsx", + ] + assert sorted(merged_graph.nodes) == expected_nodes + + expected_edges = [ + ("cleaner", "writer"), + ("csv", "joiner"), + ("docx", "joiner"), + ("html", "joiner"), + ("joiner", "splitter"), + ("json", "joiner"), + ("md", "joiner"), + ("pdf", "joiner"), + ("pptx", "joiner"), + ("router", "csv"), + ("router", "docx"), + ("router", "html"), + ("router", "json"), + ("router", "md"), + ("router", "pdf"), + ("router", "pptx"), + ("router", "text"), + ("router", "xlsx"), + ("splitter", "cleaner"), + ("text", "joiner"), + ("xlsx", "joiner"), + ] + actual_edges = [(u, v) for u, v, _ in merged_graph.edges] + assert sorted(actual_edges) == expected_edges + # UNIT def test_add_invalid_component_name(self): pipe = PipelineBase() @@ -1683,7 +1787,7 @@ def test__consume_component_inputs_with_df(self, regular_input_socket): assert consumed["input1"].equals(DataFrame({"a": [1, 2], "b": [1, 2]})) @patch("haystack.core.pipeline.draw.requests") - def test_pipeline_draw_called_with_positional_args_triggers_a_warning(self, mock_requests, tmp_path): + def test_pipeline_draw_called_with_positional_args_triggers_a_warning(self, mock_requests): """ Test that calling the pipeline draw method with 
positional arguments raises a warning. """ @@ -1694,7 +1798,7 @@ def test_pipeline_draw_called_with_positional_args_triggers_a_warning(self, mock mock_response = mock_requests.get.return_value mock_response.status_code = 200 mock_response.content = b"image_data" - out_file = tmp_path / "original_pipeline.png" + out_file = Path("original_pipeline.png") with warnings.catch_warnings(record=True) as w: pipeline.draw(out_file, server_url="http://localhost:3000") assert len(w) == 1 @@ -1728,7 +1832,7 @@ def test_pipeline_show_called_with_positional_args_triggers_a_warning(self, mock ) @patch("haystack.core.pipeline.draw.requests") - def test_pipeline_draw_called_with_keyword_args_triggers_no_warning(self, mock_requests, tmp_path): + def test_pipeline_draw_called_with_keyword_args_triggers_no_warning(self, mock_requests): """ Test that calling the pipeline draw method with keyword arguments does not raise a warning. """ @@ -1739,7 +1843,7 @@ def test_pipeline_draw_called_with_keyword_args_triggers_no_warning(self, mock_r mock_response = mock_requests.get.return_value mock_response.status_code = 200 mock_response.content = b"image_data" - out_file = tmp_path / "original_pipeline.png" + out_file = Path("original_pipeline.png") with warnings.catch_warnings(record=True) as w: pipeline.draw(path=out_file, server_url="http://localhost:3000") @@ -1763,108 +1867,60 @@ def test_pipeline_show_called_with_keyword_args_triggers_no_warning(self, mock_i pipeline.show(server_url="http://localhost:3000") assert len(w) == 0, "No warning should be triggered when using keyword arguments" - @pytest.mark.integration - def test_find_super_components(self): - """ - Test that the pipeline can find super components in it's pipeline. 
- """ - from haystack import Pipeline - from haystack.components.converters import MultiFileConverter - from haystack.components.preprocessors import DocumentPreprocessor - from haystack.components.writers import DocumentWriter - from haystack.document_stores.in_memory import InMemoryDocumentStore - - multi_file_converter = MultiFileConverter() - doc_processor = DocumentPreprocessor() - pipeline = Pipeline() - pipeline.add_component("converter", multi_file_converter) - pipeline.add_component("preprocessor", doc_processor) - pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore())) - pipeline.connect("converter", "preprocessor") - pipeline.connect("preprocessor", "writer") - - result = pipeline._find_super_components() - - assert len(result) == 2 - assert [("converter", multi_file_converter), ("preprocessor", doc_processor)] == result +class TestValidateInput: + def test_validate_input_valid_data(self): + pipe = PipelineBase() + comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})() + pipe.add_component("comp1", comp1) + pipe._validate_input(data={"comp1": {"x": 1}}) + # No exception should be raised - @pytest.mark.integration - def test_merge_super_component_pipelines(self): - from haystack import Pipeline - from haystack.components.converters import MultiFileConverter - from haystack.components.preprocessors import DocumentPreprocessor - from haystack.components.writers import DocumentWriter - from haystack.document_stores.in_memory import InMemoryDocumentStore + def test_validate_input_missing_mandatory_input(self): + pipe = PipelineBase() + comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})() + pipe.add_component("comp1", comp1) + with pytest.raises(ValueError, match="Missing mandatory input 'x' for component 'comp1'"): + pipe._validate_input(data={"comp1": {}}) - multi_file_converter = MultiFileConverter() - doc_processor = DocumentPreprocessor() + def 
test_validate_input_missing_mandatory_input_for_component_not_in_data(self):
+        pipe = PipelineBase()
+        comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})()
+        comp2 = component_class("Comp2", input_types={"a": str}, output_types={"b": str})()
+        pipe.add_component("comp1", comp1)
+        pipe.add_component("comp2", comp2)  # comp2 requires 'a' but is not in data
+        with pytest.raises(ValueError, match="Missing mandatory input 'a' for component 'comp2'"):
+            pipe._validate_input(data={"comp1": {"x": 1}})
-        pipeline = Pipeline()
-        pipeline.add_component("converter", multi_file_converter)
-        pipeline.add_component("preprocessor", doc_processor)
-        pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore()))
-        pipeline.connect("converter", "preprocessor")
-        pipeline.connect("preprocessor", "writer")
-        merged_graph, super_component_components = pipeline._merge_super_component_pipelines()
+    def test_validate_input_to_already_connected_socket(self):
+        pipe = PipelineBase()
+        comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})()
+        comp2 = component_class("Comp2", input_types={"a": int}, output_types={"b": int})()
+        pipe.add_component("comp1", comp1)
+        pipe.add_component("comp2", comp2)
+        pipe.connect("comp1.y", "comp2.a")
+        with pytest.raises(ValueError, match="Input 'a' for component 'comp2' is already provided by component 'comp1'. 
Do not provide it directly."): + pipe._validate_input(data={"comp2": {"a": 1}}) - assert super_component_components == { - "router": "converter", - "docx": "converter", - "html": "converter", - "json": "converter", - "md": "converter", - "text": "converter", - "pdf": "converter", - "pptx": "converter", - "xlsx": "converter", - "joiner": "converter", - "csv": "converter", - "splitter": "preprocessor", - "cleaner": "preprocessor", - } + def test_validate_input_for_non_existent_component(self): + pipe = PipelineBase() + with pytest.raises(ValueError, match="Component 'non_existent' not found in the pipeline. Available components: \\[\\]"): + pipe._validate_input(data={"non_existent": {"x": 1}}) - expected_nodes = [ - "cleaner", - "csv", - "docx", - "html", - "joiner", - "json", - "md", - "pdf", - "pptx", - "router", - "splitter", - "text", - "writer", - "xlsx", - ] - assert sorted(merged_graph.nodes) == expected_nodes + def test_validate_input_with_unexpected_input_name(self): + pipe = PipelineBase() + comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})() + pipe.add_component("comp1", comp1) + with pytest.raises(ValueError, match="Unexpected input 'z' for component 'comp1'. 
Available inputs: \\['x'\\]"): + pipe._validate_input(data={"comp1": {"z": 1}}) - expected_edges = [ - ("cleaner", "writer"), - ("csv", "joiner"), - ("docx", "joiner"), - ("html", "joiner"), - ("joiner", "splitter"), - ("json", "joiner"), - ("md", "joiner"), - ("pdf", "joiner"), - ("pptx", "joiner"), - ("router", "csv"), - ("router", "docx"), - ("router", "html"), - ("router", "json"), - ("router", "md"), - ("router", "pdf"), - ("router", "pptx"), - ("router", "text"), - ("router", "xlsx"), - ("splitter", "cleaner"), - ("text", "joiner"), - ("xlsx", "joiner"), - ] - actual_edges = [(u, v) for u, v, _ in merged_graph.edges] - assert sorted(actual_edges) == expected_edges + def test_validate_input_variadic_socket_can_receive_multiple_inputs(self): + pipe = PipelineBase() + comp1 = component_class("Comp1", output_types={"y": int})() + comp2 = component_class("Comp2", input_types={"a": Variadic[int]}, output_types={"b": int})() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + pipe.connect("comp1.y", "comp2.a") + # Should not raise an error, as variadic sockets can accept multiple inputs + pipe._validate_input(data={"comp2": {"a": 1}}) \ No newline at end of file