Skip to content

Commit 763e7bc

Browse files
authored
Support callable classes as history processors (#2988)
1 parent 127624d commit 763e7bc

File tree

3 files changed

+301
-10
lines changed

3 files changed

+301
-10
lines changed

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -231,31 +231,39 @@ def function_schema( # noqa: C901
231231

232232
WithCtx = Callable[Concatenate[RunContext[Any], P], R]
233233
WithoutCtx = Callable[P, R]
234-
TargetFunc = WithCtx[P, R] | WithoutCtx[P, R]
234+
TargetCallable = WithCtx[P, R] | WithoutCtx[P, R]
235235

236236

237-
def _takes_ctx(function: TargetFunc[P, R]) -> TypeIs[WithCtx[P, R]]:
238-
"""Check if a function takes a `RunContext` first argument.
237+
def _takes_ctx(callable_obj: TargetCallable[P, R]) -> TypeIs[WithCtx[P, R]]:
238+
"""Check if a callable takes a `RunContext` first argument.
239239
240240
Args:
241-
function: The function to check.
241+
callable_obj: The callable to check.
242242
243243
Returns:
244-
`True` if the function takes a `RunContext` as first argument, `False` otherwise.
244+
`True` if the callable takes a `RunContext` as first argument, `False` otherwise.
245245
"""
246246
try:
247-
sig = signature(function)
248-
except ValueError: # pragma: no cover
249-
return False # pragma: no cover
247+
sig = signature(callable_obj)
248+
except ValueError:
249+
return False
250250
try:
251251
first_param_name = next(iter(sig.parameters.keys()))
252252
except StopIteration:
253253
return False
254254
else:
255-
type_hints = _typing_extra.get_function_type_hints(function)
255+
# See https://github.com/pydantic/pydantic/pull/11451 for a similar implementation in Pydantic
256+
if not isinstance(callable_obj, _decorators._function_like): # pyright: ignore[reportPrivateUsage]
257+
call_func = getattr(type(callable_obj), '__call__', None)
258+
if call_func is not None:
259+
callable_obj = call_func
260+
else:
261+
return False # pragma: no cover
262+
263+
type_hints = _typing_extra.get_function_type_hints(_decorators.unwrap_wrapped_function(callable_obj))
256264
annotation = type_hints.get(first_param_name)
257265
if annotation is None:
258-
return False # pragma: no cover
266+
return False
259267
return True is not sig.empty and _is_call_ctx(annotation)
260268

261269

tests/test_function_schema.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
from functools import partial
2+
from typing import Any
3+
from unittest.mock import Mock
4+
5+
from pydantic_ai import RunContext
6+
from pydantic_ai._function_schema import _takes_ctx # type: ignore
7+
8+
9+
def test_regular_function_with_ctx():
10+
"""Test regular function that takes RunContext as first parameter."""
11+
12+
def func_with_ctx(ctx: RunContext[Any], x: int) -> str: ... # pragma: no cover
13+
14+
assert _takes_ctx(func_with_ctx) is True
15+
16+
17+
def test_regular_function_without_ctx():
18+
"""Test regular function that doesn't take RunContext as first parameter."""
19+
20+
def func_without_ctx(x: int, y: str) -> str: ... # pragma: no cover
21+
22+
assert _takes_ctx(func_without_ctx) is False
23+
24+
25+
def test_regular_function_no_params():
26+
"""Test regular function with no parameters."""
27+
28+
def func_no_params() -> str: ... # pragma: no cover
29+
30+
assert _takes_ctx(func_no_params) is False
31+
32+
33+
def test_regular_function_ctx_not_first():
34+
"""Test regular function where RunContext is not the first parameter."""
35+
36+
def func_ctx_not_first(x: int, ctx: RunContext[Any]) -> str: ... # pragma: no cover
37+
38+
assert _takes_ctx(func_ctx_not_first) is False
39+
40+
41+
def test_partial_function_with_ctx():
42+
"""Test partial function where original function takes RunContext as first parameter."""
43+
44+
def original_func(ctx: RunContext[Any], x: int, y: str) -> str: ... # pragma: no cover
45+
46+
# Create partial with y bound
47+
partial_func = partial(original_func, y='bound')
48+
49+
assert _takes_ctx(partial_func) is True
50+
51+
52+
def test_partial_function_without_ctx():
53+
"""Test partial function where original function doesn't take RunContext."""
54+
55+
def original_func(x: int, y: str, z: float) -> str: ... # pragma: no cover
56+
57+
# Create partial with z bound
58+
partial_func = partial(original_func, z=3.14)
59+
60+
assert _takes_ctx(partial_func) is False
61+
62+
63+
def test_partial_function_ctx_bound():
64+
"""Test partial function where RunContext parameter is bound."""
65+
66+
def original_func(ctx: RunContext[Any], x: int, y: str) -> str: ... # pragma: no cover
67+
68+
mock_ctx = Mock(spec=RunContext[Any])
69+
partial_func = partial(original_func, ctx=mock_ctx)
70+
71+
assert _takes_ctx(partial_func) is True
72+
73+
74+
def test_callable_class_with_ctx():
75+
"""Test callable class where __call__ takes RunContext as first parameter."""
76+
77+
class CallableWithCtx:
78+
def __call__(self, ctx: RunContext[Any], x: int) -> str: ... # pragma: no cover
79+
80+
callable_obj = CallableWithCtx()
81+
82+
assert _takes_ctx(callable_obj) is True
83+
84+
85+
def test_callable_class_without_ctx():
86+
"""Test callable class where __call__ doesn't take RunContext."""
87+
88+
class CallableWithoutCtx:
89+
def __call__(self, x: int, y: str) -> str: ... # pragma: no cover
90+
91+
callable_obj = CallableWithoutCtx()
92+
93+
assert _takes_ctx(callable_obj) is False
94+
95+
96+
def test_callable_class_ctx_not_first():
97+
"""Test callable class where RunContext is not the first parameter."""
98+
99+
class CallableCtxNotFirst:
100+
def __call__(self, x: int, ctx: RunContext[Any]) -> str: ... # pragma: no cover
101+
102+
callable_obj = CallableCtxNotFirst()
103+
104+
assert _takes_ctx(callable_obj) is False
105+
106+
107+
def test_method_with_ctx():
108+
"""Test bound method that takes RunContext as first parameter (after )."""
109+
110+
class TestClass:
111+
def method_with_ctx(self, ctx: RunContext[Any], x: int) -> str: ... # pragma: no cover
112+
113+
obj = TestClass()
114+
bound_method = obj.method_with_ctx
115+
116+
assert _takes_ctx(bound_method) is True
117+
118+
119+
def test_method_without_ctx():
120+
"""Test bound method that doesn't take RunContext."""
121+
122+
class TestClass:
123+
def method_without_ctx(self, x: int, y: str) -> str: ... # pragma: no cover
124+
125+
obj = TestClass()
126+
bound_method = obj.method_without_ctx
127+
128+
assert _takes_ctx(bound_method) is False
129+
130+
131+
def test_static_method_with_ctx():
132+
"""Test static method that takes RunContext as first parameter."""
133+
134+
class TestClass:
135+
@staticmethod
136+
def static_method_with_ctx(ctx: RunContext[Any], x: int) -> str: ... # pragma: no cover
137+
138+
assert _takes_ctx(TestClass.static_method_with_ctx) is True
139+
140+
141+
def test_static_method_without_ctx():
142+
"""Test static method that doesn't take RunContext."""
143+
144+
class TestClass:
145+
@staticmethod
146+
def static_method_without_ctx(x: int, y: str) -> str: ... # pragma: no cover
147+
148+
assert _takes_ctx(TestClass.static_method_without_ctx) is False
149+
150+
151+
def test_class_method_with_ctx():
152+
"""Test class method that takes RunContext as first parameter (after cls)."""
153+
154+
class TestClass:
155+
@classmethod
156+
def class_method_with_ctx(cls, ctx: RunContext[Any], x: int) -> str: ... # pragma: no cover
157+
158+
assert _takes_ctx(TestClass.class_method_with_ctx) is True
159+
160+
161+
def test_class_method_without_ctx():
162+
"""Test class method that doesn't take RunContext."""
163+
164+
class TestClass:
165+
@classmethod
166+
def class_method_without_ctx(cls, x: int, y: str) -> str: ... # pragma: no cover
167+
168+
assert _takes_ctx(TestClass.class_method_without_ctx) is False
169+
170+
171+
def test_function_no_annotations():
172+
"""Test function with no type annotations."""
173+
174+
def func_no_annotations(ctx, x): # type: ignore
175+
... # pragma: no cover
176+
177+
# Without annotations, _takes_ctx should return False
178+
assert _takes_ctx(func_no_annotations) is False # type: ignore
179+
180+
181+
def test_function_wrong_annotation_type():
182+
"""Test function with wrong annotation type for first parameter."""
183+
184+
def func_wrong_annotation(ctx: str, x: int) -> str: ... # pragma: no cover
185+
186+
assert _takes_ctx(func_wrong_annotation) is False
187+
188+
189+
def test_lambda_with_ctx():
190+
"""Test lambda function that takes RunContext as first parameter."""
191+
lambda_with_ctx = lambda ctx, x: f'{ctx.deps} {x}' # type: ignore # noqa: E731
192+
193+
# Lambda without annotations should return False
194+
assert _takes_ctx(lambda_with_ctx) is False # type: ignore
195+
196+
197+
def test_builtin_function():
198+
"""Test builtin function."""
199+
assert _takes_ctx(len) is False
200+
assert _takes_ctx(str) is False
201+
assert _takes_ctx(int) is False

tests/test_history_processor.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,3 +718,85 @@ def return_new_history(messages: list[ModelMessage]) -> list[ModelMessage]:
718718

719719
with pytest.raises(UserError, match='Processed history must end with a `ModelRequest`.'):
720720
await agent.run('foobar')
721+
722+
723+
async def test_callable_class_history_processor_no_op(
724+
function_model: FunctionModel, received_messages: list[ModelMessage]
725+
):
726+
class NoOpHistoryProcessor:
727+
def __call__(self, messages: list[ModelMessage]) -> list[ModelMessage]:
728+
return messages
729+
730+
agent = Agent(function_model, history_processors=[NoOpHistoryProcessor()])
731+
732+
message_history = [
733+
ModelRequest(parts=[UserPromptPart(content='Previous question')]),
734+
ModelResponse(parts=[TextPart(content='Previous answer')]),
735+
]
736+
737+
with capture_run_messages() as captured_messages:
738+
result = await agent.run('New question', message_history=message_history)
739+
740+
assert received_messages == snapshot(
741+
[
742+
ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
743+
ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
744+
ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
745+
]
746+
)
747+
assert captured_messages == result.all_messages()
748+
assert result.all_messages() == snapshot(
749+
[
750+
ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
751+
ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
752+
ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
753+
ModelResponse(
754+
parts=[TextPart(content='Provider response')],
755+
usage=RequestUsage(input_tokens=54, output_tokens=4),
756+
model_name='function:capture_model_function:capture_model_stream_function',
757+
timestamp=IsDatetime(),
758+
),
759+
]
760+
)
761+
assert result.new_messages() == result.all_messages()[-2:]
762+
763+
764+
async def test_callable_class_history_processor_with_ctx_no_op(
765+
function_model: FunctionModel, received_messages: list[ModelMessage]
766+
):
767+
class NoOpHistoryProcessorWithCtx:
768+
def __call__(self, _: RunContext, messages: list[ModelMessage]) -> list[ModelMessage]:
769+
return messages
770+
771+
agent = Agent(function_model, history_processors=[NoOpHistoryProcessorWithCtx()])
772+
773+
message_history = [
774+
ModelRequest(parts=[UserPromptPart(content='Previous question')]),
775+
ModelResponse(parts=[TextPart(content='Previous answer')]),
776+
]
777+
778+
with capture_run_messages() as captured_messages:
779+
result = await agent.run('New question', message_history=message_history)
780+
781+
assert received_messages == snapshot(
782+
[
783+
ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
784+
ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
785+
ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
786+
]
787+
)
788+
assert captured_messages == result.all_messages()
789+
assert result.all_messages() == snapshot(
790+
[
791+
ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
792+
ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
793+
ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
794+
ModelResponse(
795+
parts=[TextPart(content='Provider response')],
796+
usage=RequestUsage(input_tokens=54, output_tokens=4),
797+
model_name='function:capture_model_function:capture_model_stream_function',
798+
timestamp=IsDatetime(),
799+
),
800+
]
801+
)
802+
assert result.new_messages() == result.all_messages()[-2:]

0 commit comments

Comments
 (0)