Skip to content

Commit 3097319

Browse files
committed
fix: a bug where the SubGraphWorker did not correctly set up the graph and distances were not computed, which led to failure when multiple JoinedTaskWorkers were waiting on the same provenance
1 parent 0f7bf55 commit 3097319

File tree

3 files changed

+90
-4
lines changed

3 files changed

+90
-4
lines changed

src/planai/graph_task.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def consume_work(self, task: output_type):
5050
old_task = task.get_private_state(PRIVATE_STATE_KEY)
5151
if old_task is None:
5252
raise ValueError(
53-
f"No task provenance found for {PRIVATE_STATE_KEY}"
53+
f"No task state found for {PRIVATE_STATE_KEY}: {task._private_state} "
54+
f"(provenance: {task._provenance})",
5455
)
5556
assert isinstance(task, Task)
5657

@@ -75,7 +76,7 @@ def consume_work(self, task: output_type):
7576
with graph_task.lock:
7677
if provenance not in graph_task._state:
7778
raise ValueError(
78-
f"Task {provenance} does not have any associated state."
79+
f"Task {provenance} does not have any associated state: {graph_task._state}"
7980
)
8081
logging.debug(
8182
"Subgraph is removing provenance for %s in %s",
@@ -92,6 +93,9 @@ def consume_work(self, task: output_type):
9293
instance = AdapterSinkWorker()
9394
self.graph.add_workers(instance)
9495
self.graph.set_dependency(self.exit_worker, instance)
96+
self.graph.set_entry(self.entry_worker)
97+
self.graph.finalize() # compute the worker distances
98+
self.graph.init_workers()
9599

96100
def get_task_class(self) -> Type[Task]:
97101
# usually the entry task gets dynamically determined from consume_work but we are overriding it here
@@ -101,7 +105,6 @@ def init(self):
101105
# we need to install the graph dispatcher into the sub-graph
102106
assert self._graph is not None
103107
self.graph._dispatcher = self._graph._dispatcher
104-
self.graph.init_workers()
105108

106109
def consume_work(self, task: Task):
107110
new_task = task.copy_public()
@@ -143,7 +146,9 @@ def notify(self, prefix: str):
143146
self.graph.unwatch(prefix, self)
144147
with self.lock:
145148
if prefix not in self._state:
146-
raise ValueError(f"Task {prefix} does not have any associated state.")
149+
raise ValueError(
150+
f"Task {prefix} does not have any associated state: {self._state}"
151+
)
147152
task, remove_provenance = self._state.pop(prefix)
148153

149154
if remove_provenance:

tests/planai/patterns/test_search_fetch.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,42 @@ def test_search_fetch_graph_workflow(self):
223223
self.assertTrue(page.title.startswith("Test Result"))
224224
self.assertEqual(page.content, self.mock_page_content)
225225

226+
def test_search_fetch_worker_distances(self):
227+
# Create graph and get references to workers
228+
graph, search_executor, exit_worker = create_search_fetch_graph(
229+
llm=self.mock_llm, name="TestSearchFetch"
230+
)
231+
232+
# Compute distances
233+
graph.set_entry(search_executor)
234+
graph.finalize()
235+
236+
# Validate specific distances from InitialTaskWorker
237+
initial_distances = graph._worker_distances["InitialTaskWorker"]
238+
expected_order = {
239+
"SearchExecutor": 1,
240+
"SearchResultSplitter": 2,
241+
"PageFetcher": 3,
242+
"PageRelevanceFilter": 4,
243+
"PageAnalysisConsumer": 4,
244+
"PageConsolidator": 5,
245+
}
246+
for worker_name, expected_distance in expected_order.items():
247+
self.assertEqual(
248+
initial_distances[worker_name],
249+
expected_distance,
250+
f"Expected {worker_name} to be at distance {expected_distance}, "
251+
f"got {initial_distances[worker_name]}",
252+
)
253+
254+
# Validate specific distance between PageAnalysisConsumer and PageConsolidator
255+
analysis_consumer_distances = graph._worker_distances["PageAnalysisConsumer"]
256+
self.assertEqual(
257+
analysis_consumer_distances["PageConsolidator"],
258+
1,
259+
"Expected PageConsolidator to be at distance 1 from PageAnalysisConsumer",
260+
)
261+
226262

227263
if __name__ == "__main__":
228264
unittest.main()

tests/planai/test_graph_task.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,51 @@ def abort_thread():
558558
self.assertEqual(dispatcher.work_queue.qsize(), 0)
559559
self.assertEqual(dispatcher.active_tasks, 0)
560560

561+
def test_graph_task_entry_setup(self):
562+
# Create subgraph
563+
subgraph = Graph(name="SubGraph")
564+
subgraph_worker = SubGraphHandler()
565+
subgraph.add_workers(subgraph_worker)
566+
567+
# Create GraphTask
568+
_ = SubGraphWorker(
569+
graph=subgraph, entry_worker=subgraph_worker, exit_worker=subgraph_worker
570+
)
571+
572+
# Verify that entry worker has been set correctly
573+
self.assertIn(subgraph_worker, subgraph.dependencies[subgraph._initial_worker])
574+
self.assertEqual(len(subgraph.dependencies[subgraph._initial_worker]), 1)
575+
576+
def test_graph_task_distances(self):
577+
# Create subgraph with multiple workers
578+
subgraph = Graph(name="SubGraph")
579+
580+
class Worker1(TaskWorker):
581+
output_types: List[Type[Task]] = [SubGraphTask]
582+
583+
def consume_work(self, task: SubGraphTask):
584+
self.publish_work(task, input_task=task)
585+
586+
class Worker2(TaskWorker):
587+
output_types: List[Type[Task]] = [SubGraphTask]
588+
589+
def consume_work(self, task: SubGraphTask):
590+
self.publish_work(task, input_task=task)
591+
592+
worker1 = Worker1()
593+
worker2 = Worker2()
594+
subgraph.add_workers(worker1, worker2)
595+
subgraph.set_dependency(worker1, worker2)
596+
597+
# Create GraphTask
598+
_ = SubGraphWorker(graph=subgraph, entry_worker=worker1, exit_worker=worker2)
599+
600+
# Verify subgraph distances
601+
sub_distances = subgraph._worker_distances["InitialTaskWorker"]
602+
self.assertEqual(sub_distances["Worker1"], 1)
603+
self.assertEqual(sub_distances["Worker2"], 2)
604+
self.assertEqual(sub_distances["AdapterSinkWorker"], 3)
605+
561606

562607
class StatusNotifyingWorker(TaskWorker):
563608
output_types: List[Type[Task]] = [FinalTask]

0 commit comments

Comments (0)