Skip to content

Commit ee81570

Browse files
fix: only overwrite existing socket inputs when we provide a new value (#8940)
* fix: only overwrite existing socket inputs when we provide a new value * chore: add release notes * Apply suggestions from code review --------- Co-authored-by: Julian Risch <[email protected]>
1 parent db4f237 commit ee81570

File tree

5 files changed

+201
-11
lines changed

5 files changed

+201
-11
lines changed

haystack/core/pipeline/base.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,18 +1114,26 @@ def _write_component_outputs(
11141114
if receiver_name not in inputs:
11151115
inputs[receiver_name] = {}
11161116

1117-
# If we have a non-variadic or a greedy variadic receiver socket, we can just overwrite any inputs
1118-
# that might already exist (to be reconsidered but mirrors current behavior).
1119-
if not is_socket_lazy_variadic(receiver_socket):
1120-
inputs[receiver_name][receiver_socket.name] = [{"sender": component_name, "value": value}]
1121-
1122-
# If the receiver socket is lazy variadic, and it already has an input, we need to append the new input.
1123-
# Lazy variadic sockets can collect multiple inputs.
1117+
if is_socket_lazy_variadic(receiver_socket):
1118+
# If the receiver socket is lazy variadic, we append the new input.
1119+
# Lazy variadic sockets can collect multiple inputs.
1120+
_write_to_lazy_variadic_socket(
1121+
inputs=inputs,
1122+
receiver_name=receiver_name,
1123+
receiver_socket_name=receiver_socket.name,
1124+
component_name=component_name,
1125+
value=value,
1126+
)
11241127
else:
1125-
if not inputs[receiver_name].get(receiver_socket.name):
1126-
inputs[receiver_name][receiver_socket.name] = []
1127-
1128-
inputs[receiver_name][receiver_socket.name].append({"sender": component_name, "value": value})
1128+
# If the receiver socket is not lazy variadic, it is greedy variadic or non-variadic.
1129+
# We overwrite with the new input if it's not _NO_OUTPUT_PRODUCED or if the current value is None.
1130+
_write_to_standard_socket(
1131+
inputs=inputs,
1132+
receiver_name=receiver_name,
1133+
receiver_socket_name=receiver_socket.name,
1134+
component_name=component_name,
1135+
value=value,
1136+
)
11291137

11301138
# If we want to include all outputs from this actor in the final outputs, we don't need to prune any consumed
11311139
# outputs
@@ -1192,3 +1200,35 @@ def _connections_status(
11921200
receiver_sockets_list = "\n".join(receiver_sockets_entries)
11931201

11941202
return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"
1203+
1204+
1205+
# Utility functions for writing to sockets
1206+
1207+
1208+
def _write_to_lazy_variadic_socket(
1209+
inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
1210+
) -> None:
1211+
"""
1212+
Write to a lazy variadic socket.
1213+
1214+
Mutates inputs in place.
1215+
"""
1216+
if not inputs[receiver_name].get(receiver_socket_name):
1217+
inputs[receiver_name][receiver_socket_name] = []
1218+
1219+
inputs[receiver_name][receiver_socket_name].append({"sender": component_name, "value": value})
1220+
1221+
1222+
def _write_to_standard_socket(
1223+
inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
1224+
) -> None:
1225+
"""
1226+
Write to a greedy variadic or non-variadic socket.
1227+
1228+
Mutates inputs in place.
1229+
"""
1230+
current_value = inputs[receiver_name].get(receiver_socket_name)
1231+
1232+
# Only overwrite if there's no existing value, or we have a new value to provide
1233+
if current_value is None or value is not _NO_OUTPUT_PRODUCED:
1234+
inputs[receiver_name][receiver_socket_name] = [{"sender": component_name, "value": value}]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
fixes:
3+
- |
4+
Fixes an edge case in the pipeline-run logic where an existing input could be overwritten if the same component
5+
connects to the socket from multiple output sockets.

test/core/pipeline/features/pipeline_run.feature

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ Feature: Pipeline running
5555
| that is a file conversion pipeline with three joiners |
5656
| that is a file conversion pipeline with three joiners and a loop |
5757
| that has components returning dataframes |
58+
| where a single component connects multiple sockets to the same receiver socket |
59+
| where a component in a cycle provides inputs for a component outside the cycle in one iteration and no input in another iteration |
5860

5961
Scenario Outline: Running a bad Pipeline
6062
Given a pipeline <kind>

test/core/pipeline/features/test_run.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5392,3 +5392,131 @@ def run(self, dataframe: pd.DataFrame) -> Dict[str, Any]:
53925392
)
53935393
],
53945394
)
5395+
5396+
5397+
@given(
5398+
"a pipeline where a single component connects multiple sockets to the same receiver socket",
5399+
target_fixture="pipeline_data",
5400+
)
5401+
def pipeline_single_component_many_sockets_same_target(pipeline_class):
5402+
joiner = BranchJoiner(type_=str)
5403+
5404+
routes = [
5405+
{"condition": "{{query == 'route_1'}}", "output": "{{query}}", "output_name": "route_1", "output_type": str},
5406+
{"condition": "{{query == 'route_2'}}", "output": "{{query}}", "output_name": "route_2", "output_type": str},
5407+
]
5408+
5409+
router = ConditionalRouter(routes=routes)
5410+
5411+
pp = pipeline_class(max_runs_per_component=1)
5412+
5413+
pp.add_component("joiner", joiner)
5414+
pp.add_component("router", router)
5415+
5416+
pp.connect("router.route_1", "joiner.value")
5417+
pp.connect("router.route_2", "joiner.value")
5418+
5419+
return (
5420+
pp,
5421+
[
5422+
PipelineRunData(
5423+
inputs={"router": {"query": "route_1"}},
5424+
expected_outputs={"joiner": {"value": "route_1"}},
5425+
expected_component_calls={("router", 1): {"query": "route_1"}, ("joiner", 1): {"value": ["route_1"]}},
5426+
)
5427+
],
5428+
)
5429+
5430+
5431+
@given(
5432+
"a pipeline where a component in a cycle provides inputs for a component outside the cycle in one iteration and no input in another iteration",
5433+
target_fixture="pipeline_data",
5434+
)
5435+
def pipeline_component_cycle_input_no_input(pipeline_class):
5436+
joiner = BranchJoiner(type_=str)
5437+
5438+
routes = [
5439+
{
5440+
"condition": "{{query == 'iterationiterationiterationiteration'}}",
5441+
"output": "{{query}}",
5442+
"output_name": "exit",
5443+
"output_type": str,
5444+
},
5445+
{
5446+
"condition": "{{query != 'iterationiterationiterationiteration'}}",
5447+
"output": "{{query}}",
5448+
"output_name": "continue",
5449+
"output_type": str,
5450+
},
5451+
]
5452+
5453+
template = "{{query ~ query}}"
5454+
5455+
builder = PromptBuilder(template=template)
5456+
5457+
router = ConditionalRouter(routes=routes)
5458+
5459+
outside_builder = PromptBuilder(
5460+
template="{{cycle_output ~ delayed_input}}", required_variables=["cycle_output", "delayed_input"]
5461+
)
5462+
5463+
outside_routes = [
5464+
{
5465+
"condition": "{{query == 'iterationiteration'}}",
5466+
"output": "{{query}}",
5467+
"output_name": "cycle_output",
5468+
"output_type": str,
5469+
},
5470+
{
5471+
"condition": "{{query != 'iterationiteration'}}",
5472+
"output": "{{query}}",
5473+
"output_name": "no_output",
5474+
"output_type": str,
5475+
},
5476+
]
5477+
5478+
outside_router = ConditionalRouter(routes=outside_routes)
5479+
5480+
pp = pipeline_class(max_runs_per_component=1)
5481+
5482+
pp.add_component("joiner", joiner)
5483+
pp.add_component("router", router)
5484+
pp.add_component("builder", builder)
5485+
pp.add_component("outside_builder", outside_builder)
5486+
pp.add_component("outside_router", outside_router)
5487+
5488+
pp.connect("joiner.value", "builder.query")
5489+
pp.connect("builder.prompt", "router.query")
5490+
pp.connect("router.continue", "joiner.value")
5491+
pp.connect("builder.prompt", "outside_router.query")
5492+
pp.connect("outside_router.cycle_output", "outside_builder.cycle_output")
5493+
pp.connect("router.exit", "outside_builder.delayed_input")
5494+
5495+
return (
5496+
pp,
5497+
[
5498+
PipelineRunData(
5499+
inputs={"joiner": {"value": "iteration"}},
5500+
expected_outputs={
5501+
"outside_builder": {"prompt": "iterationiterationiterationiterationiterationiteration"},
5502+
"outside_router": {"no_output": "iterationiterationiterationiteration"},
5503+
},
5504+
expected_component_calls={
5505+
("joiner", 1): {"value": ["iteration"]},
5506+
("builder", 1): {"query": "iteration", "template": None, "template_variables": None},
5507+
("router", 1): {"query": "iterationiteration"},
5508+
("outside_router", 1): {"query": "iterationiteration"},
5509+
("joiner", 2): {"value": ["iterationiteration"]},
5510+
("builder", 2): {"query": "iterationiteration", "template": None, "template_variables": None},
5511+
("router", 2): {"query": "iterationiterationiterationiteration"},
5512+
("outside_router", 2): {"query": "iterationiterationiterationiteration"},
5513+
("outside_builder", 1): {
5514+
"cycle_output": "iterationiteration",
5515+
"template": None,
5516+
"template_variables": None,
5517+
"delayed_input": "iterationiterationiterationiteration",
5518+
},
5519+
},
5520+
)
5521+
],
5522+
)

test/core/pipeline/test_pipeline_base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,21 @@ def test__write_component_outputs_different_output_values(
14171417

14181418
assert inputs["receiver1"]["input1"] == [{"sender": "sender1", "value": output_value}]
14191419

1420+
def test__write_component_outputs_dont_overwrite_with_no_output(self, regular_output_socket, regular_input_socket):
1421+
"""Test that existing inputs are not overwritten with _NO_OUTPUT_PRODUCED"""
1422+
receivers = [("receiver1", regular_output_socket, regular_input_socket)]
1423+
component_outputs = {"output1": _NO_OUTPUT_PRODUCED}
1424+
inputs = {"receiver1": {"input1": [{"sender": "sender1", "value": "keep"}]}}
1425+
PipelineBase._write_component_outputs(
1426+
component_name="sender1",
1427+
component_outputs=component_outputs,
1428+
inputs=inputs,
1429+
receivers=receivers,
1430+
include_outputs_from=[],
1431+
)
1432+
1433+
assert inputs["receiver1"]["input1"] == [{"sender": "sender1", "value": "keep"}]
1434+
14201435
@pytest.mark.parametrize("receivers_count", [1, 2, 3], ids=["single-receiver", "two-receivers", "three-receivers"])
14211436
def test__write_component_outputs_multiple_receivers(
14221437
self, receivers_count, regular_output_socket, regular_input_socket

0 commit comments

Comments
 (0)