@@ -334,21 +334,16 @@ class GraphTask(GraphTaskRequest):
334334 and has a unique ID to identify the task within the graph run.
335335 """
336336
337- node_id : NodeID
338- """The ID of the node to execute."""
339-
340- inputs : Any
341- """The input data for the node."""
342-
343- fork_stack : ForkStack = field (repr = False )
344- """Stack of forks that have been entered.
345-
346- Used by the GraphRun to decide when to proceed through joins.
347- """
348-
349337 task_id : TaskID
350338 """Unique identifier for this task."""
351339
340+ @staticmethod
341+ def from_request (request : GraphTaskRequest , get_task_id : Callable [[], TaskID ]) -> GraphTask :
342+ # Don't call the get_task_id callable, this is already a task
343+ if isinstance (request , GraphTask ):
344+ return request
345+ return GraphTask (request .node_id , request .inputs , request .fork_stack , get_task_id ())
346+
352347
353348class GraphRun (Generic [StateT , DepsT , OutputT ]):
354349 """A single execution instance of a graph.
@@ -498,7 +493,7 @@ async def next(
498493 if isinstance (value , EndMarker ):
499494 self ._next = value
500495 else :
501- self ._next = [GraphTask ( gt . node_id , gt . inputs , gt . fork_stack , self ._get_next_task_id ()) for gt in value ]
496+ self ._next = [GraphTask . from_request ( gtr , self ._get_next_task_id ) for gtr in value ]
502497 return await anext (self )
503498
504499 @property
0 commit comments