Skip to content

Commit be33820

Browse files
authored
rename retriever_context -> retriever (#68)
1 parent d6f1fc9 commit be33820

File tree

13 files changed

+42
-40
lines changed

13 files changed

+42
-40
lines changed

docs/agents.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ roulette_agent = Agent( # (1)!
3131
)
3232

3333

34-
@roulette_agent.retriever_context
34+
@roulette_agent.retriever
3535
async def roulette_wheel(ctx: CallContext[int], square: int) -> str: # (2)!
3636
"""check if the square is a winner"""
3737
return 'winner' if square == ctx.deps else 'loser'
@@ -179,7 +179,9 @@ They're useful when it is impractical or impossible to put all the context an ag
179179
There are two different decorator functions to register retrievers:
180180

181181
1. [`@agent.retriever_plain`][pydantic_ai.Agent.retriever_plain] — for retrievers that don't need access to the agent [context][pydantic_ai.dependencies.CallContext]
182-
2. [`@agent.retriever_context`][pydantic_ai.Agent.retriever_context] — for retrievers that do need access to the agent [context][pydantic_ai.dependencies.CallContext]
182+
2. [`@agent.retriever`][pydantic_ai.Agent.retriever] — for retrievers that do need access to the agent [context][pydantic_ai.dependencies.CallContext]
183+
184+
`@agent.retriever` is the default since in the majority of cases retrievers will need access to the agent context.
183185

184186
Here's an example using both:
185187

@@ -205,7 +207,7 @@ def roll_dice() -> str:
205207
return str(random.randint(1, 6))
206208

207209

208-
@agent.retriever_context # (4)!
210+
@agent.retriever # (4)!
209211
def get_player_name(ctx: CallContext[str]) -> str:
210212
"""Get the player's name."""
211213
return ctx.deps
@@ -219,7 +221,7 @@ print(dice_result.data)
219221
1. This is a pretty simple task, so we can use the fast and cheap Gemini flash model.
220222
2. We pass the user's name as the dependency, to keep things simple we use just the name as a string as the dependency.
221223
3. This retriever doesn't need any context, it just returns a random number. You could probably use a dynamic system prompt in this case.
222-
4. This retriever needs the player's name, so it uses `CallContext` to access dependencies which are just the player's name.
224+
4. This retriever needs the player's name, so it uses `CallContext` to access dependencies which are just the player's name in this case.
223225
5. Run the agent, passing the player's name as the dependency.
224226

225227
_(This example is complete, it can be run "as is")_
@@ -362,7 +364,7 @@ Validation errors from both retriever parameter validation and [structured resul
362364

363365
You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [retriever](#retrievers) or [result validator functions](results.md#result-validators-functions) to tell the model it should retry.
364366

365-
- The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific retriever][pydantic_ai.Agent.retriever_context], or a [result validator][pydantic_ai.Agent.__init__].
367+
- The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific retriever][pydantic_ai.Agent.retriever], or a [result validator][pydantic_ai.Agent.__init__].
366368
- You can access the current retry count from within a retriever or result validator via [`ctx.retry`][pydantic_ai.dependencies.CallContext].
367369

368370
Here's an example:
@@ -386,7 +388,7 @@ agent = Agent(
386388
)
387389

388390

389-
@agent.retriever_context(retries=2)
391+
@agent.retriever(retries=2)
390392
def get_user_by_name(ctx: CallContext[DatabaseConn], name: str) -> int:
391393
"""Get a user's ID from their full name."""
392394
print(name)

docs/api/agent.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@
1313
- last_run_messages
1414
- system_prompt
1515
- retriever_plain
16-
- retriever_context
16+
- retriever
1717
- result_validator

docs/dependencies.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ async def get_system_prompt(ctx: CallContext[MyDeps]) -> str:
188188
return f'Prompt: {response.text}'
189189

190190

191-
@agent.retriever_context # (1)!
191+
@agent.retriever # (1)!
192192
async def get_joke_material(ctx: CallContext[MyDeps], subject: str) -> str:
193193
response = await ctx.deps.http_client.get(
194194
'https://example.com#jokes',
@@ -220,7 +220,7 @@ async def main():
220220
#> Did you hear about the toothpaste scandal? They called it Colgate.
221221
```
222222

223-
1. To pass `CallContext` and to a retriever, us the [`retriever_context`][pydantic_ai.Agent.retriever_context] decorator.
223+
1. To pass `CallContext` and to a retriever, us the [`retriever`][pydantic_ai.Agent.retriever] decorator.
224224
2. `CallContext` may optionally be passed to a [`result_validator`][pydantic_ai.Agent.result_validator] function as the first argument.
225225

226226
_(This example is complete, it can be run "as is")_
@@ -324,7 +324,7 @@ joke_agent = Agent(
324324
factory_agent = Agent('gemini-1.5-pro', result_type=list[str])
325325

326326

327-
@joke_agent.retriever_context
327+
@joke_agent.retriever
328328
async def joke_factory(ctx: CallContext[MyDeps], count: int) -> str:
329329
r = await ctx.deps.factory_agent.run(f'Please generate {count} jokes.')
330330
return '\n'.join(r.data)

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ async def add_customer_name(ctx: CallContext[SupportDependencies]) -> str:
9494
return f"The customer's name is {customer_name!r}"
9595

9696

97-
@support_agent.retriever_context # (6)!
97+
@support_agent.retriever # (6)!
9898
async def customer_balance(
9999
ctx: CallContext[SupportDependencies], include_pending: bool
100100
) -> str:

pydantic_ai/_pydantic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, R
7878

7979
if index == 0 and takes_ctx:
8080
if not _is_call_ctx(annotation):
81-
errors.append('First argument must be a CallContext instance when using `.retriever_context`')
81+
errors.append('First argument must be a CallContext instance when using `.retriever`')
8282
continue
8383
elif not takes_ctx and _is_call_ctx(annotation):
84-
errors.append('CallContext instance can only be used with `.retriever_context`')
84+
errors.append('CallContext instance can only be used with `.retriever`')
8585
continue
8686
elif index != 0 and _is_call_ctx(annotation):
8787
errors.append('CallContext instance can only be used as the first argument')

pydantic_ai/agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,16 +338,16 @@ def result_validator(
338338
return func
339339

340340
@overload
341-
def retriever_context(
341+
def retriever(
342342
self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
343343
) -> _r.Retriever[AgentDeps, RetrieverParams]: ...
344344

345345
@overload
346-
def retriever_context(
346+
def retriever(
347347
self, /, *, retries: int | None = None
348348
) -> Callable[[RetrieverContextFunc[AgentDeps, RetrieverParams]], _r.Retriever[AgentDeps, RetrieverParams]]: ...
349349

350-
def retriever_context(
350+
def retriever(
351351
self,
352352
func: RetrieverContextFunc[AgentDeps, RetrieverParams] | None = None,
353353
/,

pydantic_ai_examples/bank_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def add_customer_name(ctx: CallContext[SupportDependencies]) -> str:
6262
return f"The customer's name is {customer_name!r}"
6363

6464

65-
@support_agent.retriever_context
65+
@support_agent.retriever
6666
async def customer_balance(
6767
ctx: CallContext[SupportDependencies], include_pending: bool
6868
) -> str:

pydantic_ai_examples/rag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class Deps:
5151
agent = Agent('openai:gpt-4o', deps_type=Deps)
5252

5353

54-
@agent.retriever_context
54+
@agent.retriever
5555
async def retrieve(context: CallContext[Deps], search_query: str) -> str:
5656
"""Retrieve documentation sections based on a search query.
5757

pydantic_ai_examples/weather_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Deps:
4141
)
4242

4343

44-
@weather_agent.retriever_context
44+
@weather_agent.retriever
4545
async def get_lat_lng(
4646
ctx: CallContext[Deps], location_description: str
4747
) -> dict[str, float]:
@@ -71,7 +71,7 @@ async def get_lat_lng(
7171
raise ModelRetry('Could not find the location')
7272

7373

74-
@weather_agent.retriever_context
74+
@weather_agent.retriever
7575
async def get_weather(ctx: CallContext[Deps], lat: float, lng: float) -> dict[str, Any]:
7676
"""Get the weather at a location.
7777

tests/models/test_model_function.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ async def get_location(location_description: str) -> str:
119119
return json.dumps(lat_lng)
120120

121121

122-
@weather_agent.retriever_context
122+
@weather_agent.retriever
123123
async def get_weather(_: CallContext[None], lat: int, lng: int):
124124
if (lat, lng) == (51, 0):
125125
# it always rains in London
@@ -199,7 +199,7 @@ async def call_function_model(messages: list[Message], _: AgentInfo) -> ModelAny
199199
var_args_agent = Agent(FunctionModel(call_function_model), deps_type=int)
200200

201201

202-
@var_args_agent.retriever_context
202+
@var_args_agent.retriever
203203
def get_var_args(ctx: CallContext[int], *args: int):
204204
assert ctx.deps == 123
205205
return json.dumps({'args': args})
@@ -233,7 +233,7 @@ async def call_retriever(messages: list[Message], info: AgentInfo) -> ModelAnyRe
233233
def test_deps_none():
234234
agent = Agent(FunctionModel(call_retriever))
235235

236-
@agent.retriever_context
236+
@agent.retriever
237237
async def get_none(ctx: CallContext[None]):
238238
nonlocal called
239239

@@ -259,7 +259,7 @@ def get_check_foobar(ctx: CallContext[tuple[str, str]]) -> str:
259259
return ''
260260

261261
agent = Agent(FunctionModel(call_retriever), deps_type=tuple[str, str])
262-
agent.retriever_context(get_check_foobar)
262+
agent.retriever(get_check_foobar)
263263
called = False
264264
agent.run_sync('Hello', deps=('foo', 'bar'))
265265
assert called
@@ -277,12 +277,12 @@ def test_model_arg():
277277
agent_all = Agent()
278278

279279

280-
@agent_all.retriever_context
280+
@agent_all.retriever
281281
async def foo(_: CallContext[None], x: int) -> str:
282282
return str(x + 1)
283283

284284

285-
@agent_all.retriever_context(retries=3)
285+
@agent_all.retriever(retries=3)
286286
def bar(ctx, x: int) -> str: # pyright: ignore[reportUnknownParameterType,reportMissingParameterType]
287287
return str(x + 2)
288288

0 commit comments

Comments
 (0)