Skip to content

Commit fde6c9a

Browse files
authored
extend RunContext (#570)
1 parent 2289879 commit fde6c9a

File tree

3 files changed

+38
-33
lines changed

3 files changed

+38
-33
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@
4040

4141
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
4242

43+
# while waiting for https://github.com/pydantic/logfire/issues/745
44+
try:
45+
import logfire._internal.stack_info
46+
except ImportError:
47+
pass
48+
else:
49+
from pathlib import Path
50+
51+
logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
52+
4353
NoneType = type(None)
4454
EndStrategy = Literal['early', 'exhaustive']
4555
"""The strategy for handling multiple tool calls when a final result is found.
@@ -215,7 +225,7 @@ async def run(
215225
"""
216226
if infer_name and self.name is None:
217227
self._infer_name(inspect.currentframe())
218-
model_used, mode_selection = await self._get_model(model)
228+
model_used = await self._get_model(model)
219229

220230
deps = self._get_deps(deps)
221231
new_message_index = len(message_history) if message_history else 0
@@ -224,11 +234,10 @@ async def run(
224234
'{agent_name} run {prompt=}',
225235
prompt=user_prompt,
226236
agent=self,
227-
mode_selection=mode_selection,
228237
model_name=model_used.name(),
229238
agent_name=self.name or 'agent',
230239
) as run_span:
231-
run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
240+
run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
232241
messages = await self._prepare_messages(user_prompt, message_history, run_context)
233242
run_context.messages = messages
234243

@@ -238,15 +247,14 @@ async def run(
238247
model_settings = merge_model_settings(self.model_settings, model_settings)
239248
usage_limits = usage_limits or UsageLimits()
240249

241-
run_step = 0
242250
while True:
243251
usage_limits.check_before_request(run_context.usage)
244252

245-
run_step += 1
246-
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
253+
run_context.run_step += 1
254+
with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
247255
agent_model = await self._prepare_model(run_context)
248256

249-
with _logfire.span('model request', run_step=run_step) as model_req_span:
257+
with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
250258
model_response, request_usage = await agent_model.request(messages, model_settings)
251259
model_req_span.set_attribute('response', model_response)
252260
model_req_span.set_attribute('usage', request_usage)
@@ -255,7 +263,7 @@ async def run(
255263
run_context.usage.incr(request_usage, requests=1)
256264
usage_limits.check_tokens(run_context.usage)
257265

258-
with _logfire.span('handle model response', run_step=run_step) as handle_span:
266+
with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
259267
final_result, tool_responses = await self._handle_model_response(model_response, run_context)
260268

261269
if tool_responses:
@@ -377,7 +385,7 @@ async def main():
377385
# f_back because `asynccontextmanager` adds one frame
378386
if frame := inspect.currentframe(): # pragma: no branch
379387
self._infer_name(frame.f_back)
380-
model_used, mode_selection = await self._get_model(model)
388+
model_used = await self._get_model(model)
381389

382390
deps = self._get_deps(deps)
383391
new_message_index = len(message_history) if message_history else 0
@@ -386,11 +394,10 @@ async def main():
386394
'{agent_name} run stream {prompt=}',
387395
prompt=user_prompt,
388396
agent=self,
389-
mode_selection=mode_selection,
390397
model_name=model_used.name(),
391398
agent_name=self.name or 'agent',
392399
) as run_span:
393-
run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
400+
run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
394401
messages = await self._prepare_messages(user_prompt, message_history, run_context)
395402
run_context.messages = messages
396403

@@ -400,15 +407,14 @@ async def main():
400407
model_settings = merge_model_settings(self.model_settings, model_settings)
401408
usage_limits = usage_limits or UsageLimits()
402409

403-
run_step = 0
404410
while True:
405-
run_step += 1
411+
run_context.run_step += 1
406412
usage_limits.check_before_request(run_context.usage)
407413

408-
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
414+
with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
409415
agent_model = await self._prepare_model(run_context)
410416

411-
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
417+
with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
412418
async with agent_model.request_stream(messages, model_settings) as model_response:
413419
run_context.usage.requests += 1
414420
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
@@ -781,14 +787,14 @@ def _register_tool(self, tool: Tool[AgentDeps]) -> None:
781787

782788
self._function_tools[tool.name] = tool
783789

784-
async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
790+
async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
785791
"""Create a model configured for this agent.
786792
787793
Args:
788794
model: model to use for this run, required if `model` was not set when creating the agent.
789795
790796
Returns:
791-
a tuple of `(model used, how the model was selected)`
797+
The model used
792798
"""
793799
model_: models.Model
794800
if some_model := self._override_model:
@@ -799,18 +805,15 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
799805
'(Even when `override(model=...)` is customizing the model that will actually be called)'
800806
)
801807
model_ = some_model.value
802-
mode_selection = 'override-model'
803808
elif model is not None:
804809
model_ = models.infer_model(model)
805-
mode_selection = 'custom'
806810
elif self.model is not None:
807811
# noinspection PyTypeChecker
808812
model_ = self.model = models.infer_model(self.model)
809-
mode_selection = 'from-agent'
810813
else:
811814
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
812815

813-
return model_, mode_selection
816+
return model_
814817

815818
async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
816819
"""Build tools and create an agent model."""

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,20 @@ class RunContext(Generic[AgentDeps]):
4040

4141
deps: AgentDeps
4242
"""Dependencies for the agent."""
43-
retry: int
44-
"""Number of retries so far."""
45-
messages: list[_messages.ModelMessage]
46-
"""Messages exchanged in the conversation so far."""
47-
tool_name: str | None
48-
"""Name of the tool being called."""
4943
model: models.Model
5044
"""The model used in this run."""
5145
usage: Usage
5246
"""LLM usage associated with the run."""
47+
prompt: str
48+
"""The original user prompt passed to the run."""
49+
messages: list[_messages.ModelMessage] = field(default_factory=list)
50+
"""Messages exchanged in the conversation so far."""
51+
tool_name: str | None = None
52+
"""Name of the tool being called."""
53+
retry: int = 0
54+
"""Number of retries so far."""
55+
run_step: int = 0
56+
"""The current step in the run."""
5357

5458
def replace_with(
5559
self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET

tests/test_logfire.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ async def my_ret(x: int) -> str:
9191
)
9292
assert summary.attributes[0] == snapshot(
9393
{
94-
'code.filepath': 'agent.py',
95-
'code.function': 'run',
94+
'code.filepath': 'test_logfire.py',
95+
'code.function': 'test_logfire',
9696
'code.lineno': 123,
9797
'prompt': 'Hello',
9898
'agent': IsJson(
@@ -111,7 +111,6 @@ async def my_ret(x: int) -> str:
111111
'model_settings': None,
112112
}
113113
),
114-
'mode_selection': 'from-agent',
115114
'model_name': 'test-model',
116115
'agent_name': 'my_agent',
117116
'logfire.msg_template': '{agent_name} run {prompt=}',
@@ -176,7 +175,6 @@ async def my_ret(x: int) -> str:
176175
'model': {'type': 'object', 'title': 'TestModel', 'x-python-datatype': 'dataclass'}
177176
},
178177
},
179-
'mode_selection': {},
180178
'model_name': {},
181179
'agent_name': {},
182180
'all_messages': {
@@ -263,8 +261,8 @@ async def my_ret(x: int) -> str:
263261
)
264262
assert summary.attributes[1] == snapshot(
265263
{
266-
'code.filepath': 'agent.py',
267-
'code.function': 'run',
264+
'code.filepath': 'test_logfire.py',
265+
'code.function': 'test_logfire',
268266
'code.lineno': IsInt(),
269267
'run_step': 1,
270268
'logfire.msg_template': 'preparing model and tools {run_step=}',

0 commit comments

Comments (0)