Skip to content

Commit 296e31c

Browse files
authored
feat: Add Type Validation parameter for Pipeline Connections (#8875)
* Starting to refactor type util tests to be more systematic
* refactoring
* Expand tests
* Update to type utils
* Add missing subclass check
* Expand and refactor tests, introduce type_validation Literal
* More test refactoring
* Test refactoring, adding type validation variable to pipeline base
* Update relaxed version of type checking to pass all newly added tests
* trim whitespace
* Add tests
* cleanup
* Updates docstrings
* Add reno
* docs
* Fix mypy and add docstrings
* Changes based on advice from Tobi
* Remove unused imports
* Doc strings
* Add connection type validation to to_dict and from_dict
* Update tests
* Fix test
* Also save connection_type_validation at global pipeline level
* Fix tests
* Remove connection type validation from the connect level, only keep at pipeline level
* Formatting
* Fix tests
* formatting
1 parent 00fe4d1 commit 296e31c

File tree

13 files changed

+94
-46
lines changed

13 files changed

+94
-46
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Some examples of what you can do with Haystack:
6868
6969
> [!TIP]
7070
>
71-
> Would you like to deploy and serve Haystack pipelines as REST APIs yourself? [Hayhooks](https://github.com/deepset-ai/hayhooks) provides a simple way to wrap your pipelines with custom logic and expose them via HTTP endpoints, including OpenAI-compatible chat completion endpoints and compatibility with fully-featured chat interfaces like [open-webui](https://openwebui.com/).
71+
> Would you like to deploy and serve Haystack pipelines as REST APIs yourself? [Hayhooks](https://github.com/deepset-ai/hayhooks) provides a simple way to wrap your pipelines with custom logic and expose them via HTTP endpoints, including OpenAI-compatible chat completion endpoints and compatibility with fully-featured chat interfaces like [open-webui](https://openwebui.com/).
7272
7373
## 🆕 deepset Studio: Your Development Environment for Haystack
7474

haystack/core/pipeline/base.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@ class PipelineBase:
6868
Builds a graph of components and orchestrates their execution according to the execution graph.
6969
"""
7070

71-
def __init__(self, metadata: Optional[Dict[str, Any]] = None, max_runs_per_component: int = 100):
71+
def __init__(
72+
self,
73+
metadata: Optional[Dict[str, Any]] = None,
74+
max_runs_per_component: int = 100,
75+
connection_type_validation: bool = True,
76+
):
7277
"""
7378
Creates the Pipeline.
7479
@@ -79,12 +84,15 @@ def __init__(self, metadata: Optional[Dict[str, Any]] = None, max_runs_per_compo
7984
How many times the `Pipeline` can run the same Component.
8085
If this limit is reached a `PipelineMaxComponentRuns` exception is raised.
8186
If not set defaults to 100 runs per Component.
87+
:param connection_type_validation: Whether the pipeline will validate the types of the connections.
88+
Defaults to True.
8289
"""
8390
self._telemetry_runs = 0
8491
self._last_telemetry_sent: Optional[datetime] = None
8592
self.metadata = metadata or {}
8693
self.graph = networkx.MultiDiGraph()
8794
self._max_runs_per_component = max_runs_per_component
95+
self._connection_type_validation = connection_type_validation
8896

8997
def __eq__(self, other) -> bool:
9098
"""
@@ -142,6 +150,7 @@ def to_dict(self) -> Dict[str, Any]:
142150
"max_runs_per_component": self._max_runs_per_component,
143151
"components": components,
144152
"connections": connections,
153+
"connection_type_validation": self._connection_type_validation,
145154
}
146155

147156
@classmethod
@@ -164,7 +173,12 @@ def from_dict(
164173
data_copy = deepcopy(data) # to prevent modification of original data
165174
metadata = data_copy.get("metadata", {})
166175
max_runs_per_component = data_copy.get("max_runs_per_component", 100)
167-
pipe = cls(metadata=metadata, max_runs_per_component=max_runs_per_component)
176+
connection_type_validation = data_copy.get("connection_type_validation", True)
177+
pipe = cls(
178+
metadata=metadata,
179+
max_runs_per_component=max_runs_per_component,
180+
connection_type_validation=connection_type_validation,
181+
)
168182
components_to_reuse = kwargs.get("components", {})
169183
for name, component_data in data_copy.get("components", {}).items():
170184
if name in components_to_reuse:
@@ -402,6 +416,8 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
402416
:param receiver:
403417
The component that receives the value. This can be either just a component name or can be
404418
in the format `component_name.connection_name` if the component has multiple inputs.
419+
:param connection_type_validation: Whether the pipeline will validate the types of the connections.
420+
Defaults to the value set in the pipeline.
405421
:returns:
406422
The Pipeline instance.
407423
@@ -418,48 +434,51 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
418434

419435
# Get the nodes data.
420436
try:
421-
from_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
437+
sender_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
422438
except KeyError as exc:
423439
raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
424440
try:
425-
to_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
441+
receiver_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
426442
except KeyError as exc:
427443
raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc
428444

429445
# If the name of either socket is given, get the socket
430446
sender_socket: Optional[OutputSocket] = None
431447
if sender_socket_name:
432-
sender_socket = from_sockets.get(sender_socket_name)
448+
sender_socket = sender_sockets.get(sender_socket_name)
433449
if not sender_socket:
434450
raise PipelineConnectError(
435451
f"'{sender} does not exist. "
436452
f"Output connections of {sender_component_name} are: "
437-
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
453+
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()])
438454
)
439455

440456
receiver_socket: Optional[InputSocket] = None
441457
if receiver_socket_name:
442-
receiver_socket = to_sockets.get(receiver_socket_name)
458+
receiver_socket = receiver_sockets.get(receiver_socket_name)
443459
if not receiver_socket:
444460
raise PipelineConnectError(
445461
f"'{receiver} does not exist. "
446462
f"Input connections of {receiver_component_name} are: "
447-
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
463+
+ ", ".join(
464+
[f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()]
465+
)
448466
)
449467

450468
# Look for a matching connection among the possible ones.
451469
# Note that if there is more than one possible connection but two sockets match by name, they're paired.
452-
sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values())
470+
sender_socket_candidates: List[OutputSocket] = (
471+
[sender_socket] if sender_socket else list(sender_sockets.values())
472+
)
453473
receiver_socket_candidates: List[InputSocket] = (
454-
[receiver_socket] if receiver_socket else list(to_sockets.values())
474+
[receiver_socket] if receiver_socket else list(receiver_sockets.values())
455475
)
456476

457477
# Find all possible connections between these two components
458-
possible_connections = [
459-
(sender_sock, receiver_sock)
460-
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates)
461-
if _types_are_compatible(sender_sock.type, receiver_sock.type)
462-
]
478+
possible_connections = []
479+
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates):
480+
if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation):
481+
possible_connections.append((sender_sock, receiver_sock))
463482

464483
# We need this status for error messages, since we might need it in multiple places we calculate it here
465484
status = _connections_status(
@@ -860,7 +879,7 @@ def from_template(
860879

861880
def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
862881
"""
863-
Utility function to find all Components that receive input form `component_name`.
882+
Utility function to find all Components that receive input from `component_name`.
864883
865884
:param component_name:
866885
Name of the sender Component
@@ -1179,7 +1198,7 @@ def validate_pipeline(priority_queue: FIFOPriorityQueue) -> None:
11791198

11801199
def _connections_status(
11811200
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
1182-
):
1201+
) -> str:
11831202
"""
11841203
Lists the status of the sockets, for error messages.
11851204
"""

haystack/core/type_utils.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,42 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any, Union, get_args, get_origin
5+
from typing import Any, TypeVar, Union, get_args, get_origin
66

77
from haystack import logging
88

99
logger = logging.getLogger(__name__)
1010

11+
T = TypeVar("T")
1112

12-
def _is_optional(type_: type) -> bool:
13+
14+
def _types_are_compatible(sender, receiver, type_validation: bool = True) -> bool:
1315
"""
14-
Utility method that returns whether a type is Optional.
16+
Determines if two types are compatible based on the specified validation mode.
17+
18+
:param sender: The sender type.
19+
:param receiver: The receiver type.
20+
:param type_validation: Whether to perform strict type validation.
21+
:return: True if the types are compatible, False otherwise.
1522
"""
16-
return get_origin(type_) is Union and type(None) in get_args(type_)
23+
if type_validation:
24+
return _strict_types_are_compatible(sender, receiver)
25+
else:
26+
return True
1727

1828

19-
def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
29+
def _strict_types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
2030
"""
21-
Checks whether the source type is equal or a subtype of the destination type. Used to validate pipeline connections.
31+
Checks whether the sender type is equal to or a subtype of the receiver type under strict validation.
2232
2333
Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of
2434
typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well
2535
with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same.
26-
2736
Consider simplifying the typing of your components if you observe unexpected errors during component connection.
37+
38+
:param sender: The sender type.
39+
:param receiver: The receiver type.
40+
:return: True if the sender type is strictly compatible with the receiver type, False otherwise.
2841
"""
2942
if sender == receiver or receiver is Any:
3043
return True
@@ -42,17 +55,19 @@ def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-
4255
receiver_origin = get_origin(receiver)
4356

4457
if sender_origin is not Union and receiver_origin is Union:
45-
return any(_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))
58+
return any(_strict_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))
4659

47-
if not sender_origin or not receiver_origin or sender_origin != receiver_origin:
60+
# Both must have origins and they must be equal
61+
if not (sender_origin and receiver_origin and sender_origin == receiver_origin):
4862
return False
4963

64+
# Compare generic type arguments
5065
sender_args = get_args(sender)
5166
receiver_args = get_args(receiver)
5267
if len(sender_args) > len(receiver_args):
5368
return False
5469

55-
return all(_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
70+
return all(_strict_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
5671

5772

5873
def _type_name(type_):
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
features:
3+
- |
4+
We've introduced a new `connection_type_validation` parameter to control type compatibility checks in pipeline connections.
5+
It can be set to True (the default) to enforce type compatibility checks, or to False to skip type checking entirely and allow all connections.

test/components/connectors/test_openapi_connector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def test_serde_in_pipeline(self, monkeypatch):
151151
assert pipeline_dict == {
152152
"metadata": {},
153153
"max_runs_per_component": 100,
154+
"connection_type_validation": True,
154155
"components": {
155156
"api": {
156157
"type": "haystack.components.connectors.openapi.OpenAPIConnector",

test/components/connectors/test_openapi_service.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def test_serde_in_pipeline(self):
218218
assert pipeline_dict == {
219219
"metadata": {},
220220
"max_runs_per_component": 100,
221+
"connection_type_validation": True,
221222
"components": {
222223
"connector": {
223224
"type": "haystack.components.connectors.openapi_service.OpenAPIServiceConnector",

test/components/generators/chat/test_azure.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
171171
assert p.to_dict() == {
172172
"metadata": {},
173173
"max_runs_per_component": 100,
174+
"connection_type_validation": True,
174175
"components": {
175176
"generator": {
176177
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",

test/components/generators/chat/test_hugging_face_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def test_serde_in_pipeline(self, mock_check_valid_model):
270270
assert pipeline_dict == {
271271
"metadata": {},
272272
"max_runs_per_component": 100,
273+
"connection_type_validation": True,
273274
"components": {
274275
"generator": {
275276
"type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator",

test/components/routers/test_file_router.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ def test_serde_in_pipeline(self):
354354
assert pipeline_dict == {
355355
"metadata": {},
356356
"max_runs_per_component": 100,
357+
"connection_type_validation": True,
357358
"components": {
358359
"file_type_router": {
359360
"type": "haystack.components.routers.file_type_router.FileTypeRouter",

test/components/tools/test_tool_invoker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def test_serde_in_pipeline(self, invoker, monkeypatch):
232232
assert pipeline_dict == {
233233
"metadata": {},
234234
"max_runs_per_component": 100,
235+
"connection_type_validation": True,
235236
"components": {
236237
"invoker": {
237238
"type": "haystack.components.tools.tool_invoker.ToolInvoker",

0 commit comments

Comments
 (0)