Skip to content

Commit ae3e2df

Browse files
committed
Improve typing
1 parent b515963 commit ae3e2df

File tree

1 file changed

+9
-8
lines changed
  • pydantic_graph/pydantic_graph/beta

1 file changed

+9
-8
lines changed

pydantic_graph/pydantic_graph/beta/paths.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
DepsT = TypeVar('DepsT', infer_variance=True)
2424
OutputT = TypeVar('OutputT', infer_variance=True)
2525
InputT = TypeVar('InputT', infer_variance=True)
26+
T = TypeVar('T')
2627

2728
if TYPE_CHECKING:
2829
from pydantic_graph.beta.node_types import AnyDestinationNode, DestinationNode, SourceNode
@@ -226,7 +227,7 @@ def fork(self, forks: Sequence[Path], /, *, fork_id: str | None = None) -> Path:
226227
next_item = BroadcastMarker(paths=forks, fork_id=ForkID(NodeID(fork_id or 'broadcast_' + secrets.token_hex(8))))
227228
return Path(items=[*self.working_items, next_item])
228229

229-
def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) -> PathBuilder[StateT, DepsT, Any]:
230+
def transform(self, func: TransformFunction[StateT, DepsT, OutputT, T], /) -> PathBuilder[StateT, DepsT, T]:
230231
"""Add a transformation step to the path.
231232
232233
Args:
@@ -236,14 +237,14 @@ def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) ->
236237
A new PathBuilder with the transformation added
237238
"""
238239
next_item = TransformMarker(func)
239-
return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item])
240+
return PathBuilder[StateT, DepsT, T](working_items=[*self.working_items, next_item])
240241

241242
def map(
242-
self: PathBuilder[StateT, DepsT, Iterable[Any]],
243+
self: PathBuilder[StateT, DepsT, Iterable[T]],
243244
*,
244245
fork_id: ForkID | None = None,
245246
downstream_join_id: JoinID | None = None,
246-
) -> PathBuilder[StateT, DepsT, Any]:
247+
) -> PathBuilder[StateT, DepsT, T]:
247248
"""Spread iterable data across parallel execution paths.
248249
249250
This method can only be called when the current output type is iterable.
@@ -259,7 +260,7 @@ def map(
259260
next_item = MapMarker(
260261
fork_id=NodeID(fork_id or 'map_' + secrets.token_hex(8)), downstream_join_id=downstream_join_id
261262
)
262-
return PathBuilder[StateT, DepsT, Any](working_items=[*self.working_items, next_item])
263+
return PathBuilder[StateT, DepsT, T](working_items=[*self.working_items, next_item])
263264

264265
def label(self, label: str, /) -> PathBuilder[StateT, DepsT, OutputT]:
265266
"""Add a human-readable label to this point in the path.
@@ -396,11 +397,11 @@ def to(
396397
)
397398

398399
def map(
399-
self: EdgePathBuilder[StateT, DepsT, Iterable[Any]],
400+
self: EdgePathBuilder[StateT, DepsT, Iterable[T]],
400401
*,
401402
fork_id: ForkID | None = None,
402403
downstream_join_id: JoinID | None = None,
403-
) -> EdgePathBuilder[StateT, DepsT, Any]:
404+
) -> EdgePathBuilder[StateT, DepsT, T]:
404405
"""Spread iterable data across parallel execution paths.
405406
406407
Args:
@@ -415,7 +416,7 @@ def map(
415416
path_builder=self.path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id),
416417
)
417418

418-
def transform(self, func: TransformFunction[StateT, DepsT, OutputT, Any], /) -> EdgePathBuilder[StateT, DepsT, Any]:
419+
def transform(self, func: TransformFunction[StateT, DepsT, OutputT, T], /) -> EdgePathBuilder[StateT, DepsT, T]:
419420
"""Add a transformation step to the edge path.
420421
421422
Args:

0 commit comments

Comments
 (0)