Skip to content

Commit 83e9312

Browse files
committed
WIP
1 parent 55a8d42 commit 83e9312

File tree

5 files changed

+94
-96
lines changed

5 files changed

+94
-96
lines changed

pydantic_graph/pydantic_graph/_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import types
55
import warnings
6+
import inspect
67
from collections.abc import Callable, Generator
78
from contextlib import contextmanager
89
from functools import partial
@@ -163,3 +164,41 @@ def logfire_span(*args: Any, **kwargs: Any) -> Generator[LogfireSpan, None, None
163164
warnings.filterwarnings('ignore', category=LogfireNotConfiguredWarning)
164165
with _logfire.span(*args, **kwargs) as span:
165166
yield span
167+
168+
169+
def infer_obj_name(obj: Any, *, depth: int) -> str | None:
170+
"""Infer the variable name of an object from the calling frame's scope.
171+
172+
This function examines the call stack to find what variable name was used
173+
for the given object in the calling scope. This is useful for automatic
174+
naming of objects based on their variable names.
175+
176+
Args:
177+
obj: The object whose variable name to infer.
178+
depth: Number of stack frames to traverse upward from the current frame.
179+
180+
Returns:
181+
The inferred variable name if found, None otherwise.
182+
183+
Example:
184+
Usage should generally look like `infer_name(self, depth=2)` or similar.
185+
"""
186+
target_frame = inspect.currentframe()
187+
if target_frame is None:
188+
return None # pragma: no cover
189+
for _ in range(depth):
190+
target_frame = target_frame.f_back
191+
if target_frame is None:
192+
return None
193+
194+
for name, item in target_frame.f_locals.items():
195+
if item is obj:
196+
return name
197+
198+
if target_frame.f_locals != target_frame.f_globals: # pragma: no branch
199+
# if we couldn't find the agent in locals and globals are a different dict, try globals
200+
for name, item in target_frame.f_globals.items():
201+
if item is obj:
202+
return name
203+
204+
return None

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from pydantic_ai.exceptions import ExceptionGroup
2424
from pydantic_graph import exceptions
25-
from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span
25+
from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span, infer_obj_name
2626
from pydantic_graph.beta.decision import Decision
2727
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
2828
from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
@@ -187,7 +187,9 @@ async def run(
187187
The final output from the graph execution
188188
"""
189189
if infer_name and self.name is None:
190-
self._infer_name(inspect.currentframe())
190+
inferred_name = infer_obj_name(self, depth=2)
191+
if inferred_name is not None:
192+
self.name = inferred_name
191193

192194
async with self.iter(state=state, deps=deps, inputs=inputs, span=span, infer_name=False) as graph_run:
193195
# Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` method,
@@ -602,7 +604,7 @@ async def iter_graph( # noqa C901
602604
for afs in active_fork_stacks
603605
):
604606
# this join_state is a strict prefix for one of the other active join_states
605-
continue # pragma: no cover # TODO: We should cover this
607+
continue # pragma: no cover # It's difficult to cover this
606608
self.active_reducers.pop(
607609
(join_id, fork_run_id)
608610
) # we're handling it now, so we can pop it

pydantic_graph/pydantic_graph/beta/util.py

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
including workarounds for type checker limitations and utilities for runtime type inspection.
55
"""
66

7-
import inspect
87
from dataclasses import dataclass
98
from typing import Any, Generic, cast, get_args, get_origin
109

@@ -17,16 +16,16 @@
1716
class TypeExpression(Generic[T]):
1817
"""A workaround for type checker limitations when using complex type expressions.
1918
20-
This class serves as a wrapper for types that cannot normally be used in positions
19+
This class serves as a wrapper for types that cannot normally be used in positions
2120
requiring `type[T]`, such as `Any`, `Union[...]`, or `Literal[...]`. It provides a
22-
way to pass these complex type expressions to functions expecting concrete types.
21+
way to pass these complex type expressions to functions expecting concrete types.
2322
24-
Example:
25-
Instead of `output_type=Union[str, int]` (which may cause type errors),
26-
use `output_type=TypeExpression[Union[str, int]]`.
23+
Example:
24+
Instead of `output_type=Union[str, int]` (which may cause type errors),
25+
use `output_type=TypeExpression[Union[str, int]]`.
2726
28-
Note:
29-
This is a workaround for the lack of TypeForm in the Python type system.
27+
Note:
28+
This is a workaround for the lack of TypeForm in the Python type system.
3029
"""
3130

3231
pass
@@ -89,44 +88,3 @@ def get_callable_name(callable_: Any) -> str:
8988
The callable's __name__ attribute if available, otherwise its string representation.
9089
"""
9190
return getattr(callable_, '__name__', str(callable_))
92-
93-
94-
def infer_name(obj: Any, *, depth: int) -> str | None:
95-
"""Infer the variable name of an object from the calling frame's scope.
96-
97-
This function examines the call stack to find what variable name was used
98-
for the given object in the calling scope. This is useful for automatic
99-
naming of objects based on their variable names.
100-
101-
Args:
102-
obj: The object whose variable name to infer.
103-
depth: Number of stack frames to traverse upward from the current frame.
104-
105-
Returns:
106-
The inferred variable name if found, None otherwise.
107-
108-
Example:
109-
Usage should generally look like `infer_name(self, depth=2)` or similar.
110-
111-
Note:
112-
TODO(P3): Use this or lose it
113-
"""
114-
target_frame = inspect.currentframe()
115-
if target_frame is None:
116-
return None # pragma: no cover
117-
for _ in range(depth):
118-
target_frame = target_frame.f_back
119-
if target_frame is None:
120-
return None
121-
122-
for name, item in target_frame.f_locals.items():
123-
if item is obj:
124-
return name
125-
126-
if target_frame.f_locals != target_frame.f_globals: # pragma: no branch
127-
# if we couldn't find the agent in locals and globals are a different dict, try globals
128-
for name, item in target_frame.f_globals.items():
129-
if item is obj:
130-
return name
131-
132-
return None

tests/graph/beta/test_util.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
Some,
55
TypeExpression,
66
get_callable_name,
7-
infer_name,
87
unpack_type_expression,
98
)
109

@@ -48,45 +47,3 @@ class MyClass:
4847
name = get_callable_name(obj)
4948
assert isinstance(name, str)
5049
assert 'object' in name
51-
52-
53-
def test_infer_name():
54-
"""Test inferring variable names from the calling frame."""
55-
my_object = object()
56-
# Depth 1 means we look at the frame calling infer_name
57-
inferred = infer_name(my_object, depth=1)
58-
assert inferred == 'my_object'
59-
60-
# Test with object not in locals
61-
result = infer_name(object(), depth=1)
62-
assert result is None
63-
64-
65-
def test_infer_name_no_frame():
66-
"""Test infer_name when frame inspection fails."""
67-
# This is hard to trigger without mocking, but we can test that the function
68-
# returns None gracefully when it can't find the object
69-
some_obj = object()
70-
71-
# Call with depth that would exceed the call stack
72-
result = infer_name(some_obj, depth=1000)
73-
assert result is None
74-
75-
76-
global_obj = object()
77-
78-
79-
def test_infer_name_locals_vs_globals():
80-
"""Test infer_name prefers locals over globals."""
81-
result = infer_name(global_obj, depth=1)
82-
assert result == 'global_obj'
83-
84-
# Assign a local name to the variable and ensure it is found with precedence over the global
85-
local_obj = global_obj
86-
result = infer_name(global_obj, depth=1)
87-
assert result == 'local_obj'
88-
89-
# If we unbind the local name, should find the global name again
90-
del local_obj
91-
result = infer_name(global_obj, depth=1)
92-
assert result == 'global_obj'

tests/graph/test_utils.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from threading import Thread
22

3-
from pydantic_graph._utils import get_event_loop
3+
from pydantic_graph._utils import get_event_loop, infer_obj_name
44

55

66
def test_get_event_loop_in_thread():
@@ -11,3 +11,45 @@ def get_and_close_event_loop():
1111
thread = Thread(target=get_and_close_event_loop)
1212
thread.start()
1313
thread.join()
14+
15+
16+
def test_infer_obj_name():
17+
"""Test inferring variable names from the calling frame."""
18+
my_object = object()
19+
# Depth 1 means we look at the frame calling infer_obj_name
20+
inferred = infer_obj_name(my_object, depth=1)
21+
assert inferred == 'my_object'
22+
23+
# Test with object not in locals
24+
result = infer_obj_name(object(), depth=1)
25+
assert result is None
26+
27+
28+
def test_infer_obj_name_no_frame():
29+
"""Test infer_obj_name when frame inspection fails."""
30+
# This is hard to trigger without mocking, but we can test that the function
31+
# returns None gracefully when it can't find the object
32+
some_obj = object()
33+
34+
# Call with depth that would exceed the call stack
35+
result = infer_obj_name(some_obj, depth=1000)
36+
assert result is None
37+
38+
39+
global_obj = object()
40+
41+
42+
def test_infer_obj_name_locals_vs_globals():
43+
"""Test infer_obj_name prefers locals over globals."""
44+
result = infer_obj_name(global_obj, depth=1)
45+
assert result == 'global_obj'
46+
47+
# Assign a local name to the variable and ensure it is found with precedence over the global
48+
local_obj = global_obj
49+
result = infer_obj_name(global_obj, depth=1)
50+
assert result == 'local_obj'
51+
52+
# If we unbind the local name, should find the global name again
53+
del local_obj
54+
result = infer_obj_name(global_obj, depth=1)
55+
assert result == 'global_obj'

0 commit comments

Comments
 (0)