Skip to content

Commit 56f1e5a

Browse files
committed
Make StepContext a dataclass again
1 parent 12e8883 commit 56f1e5a

File tree

2 files changed

+14
-48
lines changed

2 files changed

+14
-48
lines changed

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
492492
reducer, _ = reducer_and_fork_stack
493493

494494
try:
495-
reducer.reduce(StepContext(self.state, self.deps, result.inputs))
495+
reducer.reduce(StepContext(state=self.state, deps=self.deps, inputs=result.inputs))
496496
except StopIteration:
497497
# cancel all concurrently running tasks with the same fork_run_id of the parent fork
498498
task_ids_to_cancel = set[TaskID]()
@@ -522,7 +522,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
522522

523523
for join_id, fork_run_id in self._get_completed_fork_runs(source_task, tasks_by_id.values()):
524524
reducer, fork_stack = self._active_reducers.pop((join_id, fork_run_id))
525-
output = reducer.finalize(StepContext(self.state, self.deps, None))
525+
output = reducer.finalize(StepContext(state=self.state, deps=self.deps, inputs=None))
526526
join_node = self.graph.nodes[join_id]
527527
assert isinstance(
528528
join_node, Join
@@ -545,7 +545,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
545545
continue # this reducer is a strict prefix for one of the other active reducers
546546

547547
self._active_reducers.pop((join_id, fork_run_id)) # we're finalizing it now
548-
output = reducer.finalize(StepContext(self.state, self.deps, None))
548+
output = reducer.finalize(StepContext(state=self.state, deps=self.deps, inputs=None))
549549
join_node = self.graph.nodes[join_id]
550550
assert isinstance(join_node, Join) # We could drop this but if it fails it means there is a bug.
551551
new_tasks = self._handle_edges(join_node, output, fork_stack)
@@ -576,7 +576,7 @@ async def _handle_task(
576576
if self.graph.auto_instrument:
577577
stack.enter_context(logfire_span('run node {node_id}', node_id=node.id, node=node))
578578

579-
step_context = StepContext[StateT, DepsT, Any](state, deps, inputs)
579+
step_context = StepContext[StateT, DepsT, Any](state=state, deps=deps, inputs=inputs)
580580
output = await node.call(step_context)
581581
if isinstance(node, NodeStep):
582582
return self._handle_node(output, fork_stack)
@@ -684,7 +684,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
684684
elif isinstance(item, BroadcastMarker):
685685
return [GraphTask(item.fork_id, inputs, fork_stack)]
686686
elif isinstance(item, TransformMarker):
687-
inputs = item.transform(StepContext(self.state, self.deps, inputs))
687+
inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
688688
return self._handle_path(path.next_path, inputs, fork_stack)
689689
elif isinstance(item, LabelMarker):
690690
return self._handle_path(path.next_path, inputs, fork_stack)

pydantic_graph/pydantic_graph/beta/step.py

Lines changed: 9 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from __future__ import annotations
99

1010
from collections.abc import Awaitable
11-
from dataclasses import dataclass
12-
from typing import TYPE_CHECKING, Any, Generic, Protocol, cast, get_origin, overload
11+
from dataclasses import dataclass, field
12+
from typing import Any, Generic, Protocol, cast, get_origin, overload
1313

1414
from typing_extensions import TypeVar
1515

@@ -22,6 +22,7 @@
2222
OutputT = TypeVar('OutputT', infer_variance=True)
2323

2424

25+
@dataclass(kw_only=True, frozen=True)
2526
class StepContext(Generic[StateT, DepsT, InputT]):
2627
"""Context information passed to step functions during graph execution.
2728
@@ -35,49 +36,14 @@ class StepContext(Generic[StateT, DepsT, InputT]):
3536
InputT: The type of the input data
3637
"""
3738

38-
if TYPE_CHECKING:
39+
state: StateT = field(repr=False) # exclude from repr to keep things concise
40+
"""The current graph state."""
3941

40-
def __init__(self, state: StateT, deps: DepsT, inputs: InputT):
41-
self._state = state
42-
self._deps = deps
43-
self._inputs = inputs
42+
deps: DepsT = field(repr=False) # exclude from repr to keep things concise
43+
"""The dependencies available to this step."""
4444

45-
@property
46-
def state(self) -> StateT:
47-
"""The current graph state."""
48-
return self._state
49-
50-
@property
51-
def deps(self) -> DepsT:
52-
"""The dependencies available to this step."""
53-
return self._deps
54-
55-
@property
56-
def inputs(self) -> InputT:
57-
"""The input data for this step."""
58-
return self._inputs
59-
else:
60-
state: StateT
61-
"""The current graph state."""
62-
63-
deps: DepsT
64-
"""The dependencies available to this step."""
65-
66-
inputs: InputT
67-
"""The input data for this step."""
68-
69-
def __repr__(self) -> str:
70-
"""Return a string representation of the step context.
71-
72-
Returns:
73-
A string showing the class name and inputs
74-
"""
75-
return f'{self.__class__.__name__}(inputs={self.inputs})'
76-
77-
78-
if not TYPE_CHECKING:
79-
# TODO: Try dropping inputs from StepContext, it would make for fewer generic params, could help
80-
StepContext = dataclass(StepContext)
45+
inputs: InputT
46+
"""The input data for this step."""
8147

8248

8349
class StepFunction(Protocol[StateT, DepsT, InputT, OutputT]):

0 commit comments

Comments
 (0)