88from __future__ import annotations as _annotations
99
1010import sys
11- import uuid
12- from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Iterable , Sequence
11+ from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Callable , Iterable , Sequence
1312from contextlib import AbstractContextManager , AsyncExitStack , ExitStack , asynccontextmanager , contextmanager
1413from dataclasses import dataclass , field
1514from typing import TYPE_CHECKING , Any , Generic , Literal , TypeGuard , cast , get_args , get_origin , overload
2221from pydantic_graph import exceptions
2322from pydantic_graph ._utils import AbstractSpan , get_traceparent , infer_obj_name , logfire_span
2423from pydantic_graph .beta .decision import Decision
25- from pydantic_graph .beta .id_types import ForkID , ForkStack , ForkStackItem , GraphRunID , JoinID , NodeID , NodeRunID , TaskID
24+ from pydantic_graph .beta .id_types import ForkID , ForkStack , ForkStackItem , JoinID , NodeID , NodeRunID , TaskID
2625from pydantic_graph .beta .join import Join , JoinNode , JoinState , ReducerContext
2726from pydantic_graph .beta .node import (
2827 EndNode ,
@@ -306,14 +305,13 @@ def __str__(self) -> str:
306305
307306
308307@dataclass
309- class GraphTask :
310- """A single task representing the execution of a node in the graph.
308+ class GraphTaskRequest :
309+ """A request to run a task representing the execution of a node in the graph.
311310
312- GraphTask encapsulates all the information needed to execute a specific
311+ GraphTaskRequest encapsulates all the information needed to execute a specific
313312 node, including its inputs and the fork context it's executing within.
314313 """
315314
316- # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
317315 node_id : NodeID
318316 """The ID of the node to execute."""
319317
@@ -326,7 +324,29 @@ class GraphTask:
326324 Used by the GraphRun to decide when to proceed through joins.
327325 """
328326
329- task_id : TaskID = field (default_factory = lambda : TaskID (str (uuid .uuid4 ())), repr = False )
327+
328+ @dataclass
329+ class GraphTask (GraphTaskRequest ):
330+ """A task representing the execution of a node in the graph.
331+
332+ GraphTask encapsulates all the information needed to execute a specific
333+ node, including its inputs and the fork context it's executing within,
334+ and has a unique ID to identify the task within the graph run.
335+ """
336+
337+ node_id : NodeID
338+ """The ID of the node to execute."""
339+
340+ inputs : Any
341+ """The input data for the node."""
342+
343+ fork_stack : ForkStack = field (repr = False )
344+ """Stack of forks that have been entered.
345+
346+ Used by the GraphRun to decide when to proceed through joins.
347+ """
348+
349+ task_id : TaskID
330350 """Unique identifier for this task."""
331351
332352
@@ -378,12 +398,20 @@ def __init__(
378398 self ._next : EndMarker [OutputT ] | Sequence [GraphTask ] | None = None
379399 """The next item to be processed."""
380400
381- run_id = GraphRunID (str (uuid .uuid4 ()))
382- initial_fork_stack : ForkStack = (ForkStackItem (StartNode .id , NodeRunID (run_id ), 0 ),)
383- self ._first_task = GraphTask (node_id = StartNode .id , inputs = inputs , fork_stack = initial_fork_stack )
401+ self ._next_task_id = 0
402+ self ._next_node_run_id = 0
403+ initial_fork_stack : ForkStack = (ForkStackItem (StartNode .id , self ._get_next_node_run_id (), 0 ),)
404+ self ._first_task = GraphTask (
405+ node_id = StartNode .id , inputs = inputs , fork_stack = initial_fork_stack , task_id = self ._get_next_task_id ()
406+ )
384407 self ._iterator_task_group = create_task_group ()
385408 self ._iterator_instance = _GraphIterator [StateT , DepsT , OutputT ](
386- self .graph , self .state , self .deps , self ._iterator_task_group
409+ self .graph ,
410+ self .state ,
411+ self .deps ,
412+ self ._iterator_task_group ,
413+ self ._get_next_node_run_id ,
414+ self ._get_next_task_id ,
387415 )
388416 self ._iterator = self ._iterator_instance .iter_graph (self ._first_task )
389417
@@ -449,7 +477,7 @@ async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
449477 return self ._next
450478
451479 async def next (
452- self , value : EndMarker [OutputT ] | Sequence [GraphTask ] | None = None
480+ self , value : EndMarker [OutputT ] | Sequence [GraphTaskRequest ] | None = None
453481 ) -> EndMarker [OutputT ] | Sequence [GraphTask ]:
454482 """Advance the graph execution by one step.
455483
@@ -467,7 +495,10 @@ async def next(
467495 # if `next` is called before the `first_node` has run.
468496 await anext (self )
469497 if value is not None :
470- self ._next = value
498+ if isinstance (value , EndMarker ):
499+ self ._next = value
500+ else :
501+ self ._next = [GraphTask (gt .node_id , gt .inputs , gt .fork_stack , self ._get_next_task_id ()) for gt in value ]
471502 return await anext (self )
472503
473504 @property
@@ -490,6 +521,16 @@ def output(self) -> OutputT | None:
490521 return self ._next .value
491522 return None
492523
524+ def _get_next_task_id (self ) -> TaskID :
525+ next_id = TaskID (f'task:{ self ._next_task_id } ' )
526+ self ._next_task_id += 1
527+ return next_id
528+
529+ def _get_next_node_run_id (self ) -> NodeRunID :
530+ next_id = NodeRunID (f'task:{ self ._next_node_run_id } ' )
531+ self ._next_node_run_id += 1
532+ return next_id
533+
493534
494535@dataclass
495536class _GraphTaskAsyncIterable :
@@ -510,6 +551,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
510551 state : StateT
511552 deps : DepsT
512553 task_group : TaskGroup
554+ get_next_node_run_id : Callable [[], NodeRunID ]
555+ get_next_task_id : Callable [[], TaskID ]
513556
514557 cancel_scopes : dict [TaskID , CancelScope ] = field (init = False )
515558 active_tasks : dict [TaskID , GraphTask ] = field (init = False )
@@ -522,6 +565,7 @@ def __post_init__(self):
522565 self .active_tasks = {}
523566 self .active_reducers = {}
524567 self .iter_stream_sender , self .iter_stream_receiver = create_memory_object_stream [_GraphTaskResult ]()
568+ self ._next_node_run_id = 1
525569
526570 async def iter_graph ( # noqa C901
527571 self , first_task : GraphTask
@@ -782,12 +826,12 @@ def _handle_node(
782826 fork_stack : ForkStack ,
783827 ) -> Sequence [GraphTask ] | JoinItem | EndMarker [OutputT ]:
784828 if isinstance (next_node , StepNode ):
785- return [GraphTask (next_node .step .id , next_node .inputs , fork_stack )]
829+ return [GraphTask (next_node .step .id , next_node .inputs , fork_stack , self . get_next_task_id () )]
786830 elif isinstance (next_node , JoinNode ):
787831 return JoinItem (next_node .join .id , next_node .inputs , fork_stack )
788832 elif isinstance (next_node , BaseNode ):
789833 node_step = NodeStep (next_node .__class__ )
790- return [GraphTask (node_step .id , next_node , fork_stack )]
834+ return [GraphTask (node_step .id , next_node , fork_stack , self . get_next_task_id () )]
791835 elif isinstance (next_node , End ):
792836 return EndMarker (next_node .data )
793837 else :
@@ -821,7 +865,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
821865 'These markers should be removed from paths during graph building'
822866 )
823867 if isinstance (item , DestinationMarker ):
824- return [GraphTask (item .destination_id , inputs , fork_stack )]
868+ return [GraphTask (item .destination_id , inputs , fork_stack , self . get_next_task_id () )]
825869 elif isinstance (item , TransformMarker ):
826870 inputs = item .transform (StepContext (state = self .state , deps = self .deps , inputs = inputs ))
827871 return self ._handle_path (path .next_path , inputs , fork_stack )
@@ -853,7 +897,7 @@ def _handle_fork_edges(
853897 ) # this should have already been ensured during graph building
854898
855899 new_tasks : list [GraphTask ] = []
856- node_run_id = NodeRunID ( str ( uuid . uuid4 ()) )
900+ node_run_id = self . get_next_node_run_id ( )
857901 if node .is_map :
858902 # If the map specifies a downstream join id, eagerly create a join state for it
859903 if (join_id := node .downstream_join_id ) is not None :
0 commit comments