Skip to content

Commit 40b8f9a

Browse files
committed
Reuse task IDs when possible
1 parent 99afd08 commit 40b8f9a

File tree

1 file changed

+8
-13
lines changed
  • pydantic_graph/pydantic_graph/beta

1 file changed

+8
-13
lines changed

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -334,21 +334,16 @@ class GraphTask(GraphTaskRequest):
334334
and has a unique ID to identify the task within the graph run.
335335
"""
336336

337-
node_id: NodeID
338-
"""The ID of the node to execute."""
339-
340-
inputs: Any
341-
"""The input data for the node."""
342-
343-
fork_stack: ForkStack = field(repr=False)
344-
"""Stack of forks that have been entered.
345-
346-
Used by the GraphRun to decide when to proceed through joins.
347-
"""
348-
349337
task_id: TaskID
350338
"""Unique identifier for this task."""
351339

340+
@staticmethod
341+
def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask:
342+
# Don't call the get_task_id callable, this is already a task
343+
if isinstance(request, GraphTask):
344+
return request
345+
return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id())
346+
352347

353348
class GraphRun(Generic[StateT, DepsT, OutputT]):
354349
"""A single execution instance of a graph.
@@ -498,7 +493,7 @@ async def next(
498493
if isinstance(value, EndMarker):
499494
self._next = value
500495
else:
501-
self._next = [GraphTask(gt.node_id, gt.inputs, gt.fork_stack, self._get_next_task_id()) for gt in value]
496+
self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value]
502497
return await anext(self)
503498

504499
@property

0 commit comments

Comments
 (0)