Skip to content

Commit e551ee1

Browse files
committed
Merge remote-tracking branch 'pydantic/main' into qian/dbos-agent
2 parents 81a9d7a + 10339bb commit e551ee1

26 files changed

+556
-277
lines changed

docs/cli.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,25 @@ async def main():
108108
```
109109

110110
_(You'll need to add `asyncio.run(main())` to run `main`)_
111+
112+
### Message History
113+
114+
Both `Agent.to_cli()` and `Agent.to_cli_sync()` support a `message_history` parameter, allowing you to continue an existing conversation or provide conversation context:
115+
116+
```python {title="agent_with_history.py" test="skip"}
117+
from pydantic_ai import Agent
118+
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, UserPromptPart, TextPart
119+
120+
agent = Agent('openai:gpt-4.1')
121+
122+
# Create some conversation history
123+
message_history: list[ModelMessage] = [
124+
ModelRequest([UserPromptPart(content='What is 2+2?')]),
125+
ModelResponse([TextPart(content='2+2 equals 4.')])
126+
]
127+
128+
# Start CLI with existing conversation context
129+
agent.to_cli_sync(message_history=message_history)
130+
```
131+
132+
The CLI will start with the provided conversation history, allowing the agent to refer back to previous exchanges and maintain context throughout the session.

docs/message-history.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`A
1818

1919
E.g. you've awaited one of the following coroutines:
2020

21-
* [`StreamedRunResult.stream()`][pydantic_ai.result.StreamedRunResult.stream]
21+
* [`StreamedRunResult.stream_output()`][pydantic_ai.result.StreamedRunResult.stream_output]
2222
* [`StreamedRunResult.stream_text()`][pydantic_ai.result.StreamedRunResult.stream_text]
23-
* [`StreamedRunResult.stream_structured()`][pydantic_ai.result.StreamedRunResult.stream_structured]
23+
* [`StreamedRunResult.stream_responses()`][pydantic_ai.result.StreamedRunResult.stream_responses]
2424
* [`StreamedRunResult.get_output()`][pydantic_ai.result.StreamedRunResult.get_output]
2525

2626
**Note:** The final result message will NOT be added to result messages if you use [`.stream_text(delta=True)`][pydantic_ai.result.StreamedRunResult.stream_text] since in this case the result content is never built as one string.

docs/output.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ agent = Agent(
573573
async def main():
574574
user_input = 'My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid.'
575575
async with agent.run_stream(user_input) as result:
576-
async for profile in result.stream():
576+
async for profile in result.stream_output():
577577
print(profile)
578578
#> {'name': 'Ben'}
579579
#> {'name': 'Ben'}
@@ -609,9 +609,9 @@ agent = Agent('openai:gpt-4o', output_type=UserProfile)
609609
async def main():
610610
user_input = 'My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid.'
611611
async with agent.run_stream(user_input) as result:
612-
async for message, last in result.stream_structured(debounce_by=0.01): # (1)!
612+
async for message, last in result.stream_responses(debounce_by=0.01): # (1)!
613613
try:
614-
profile = await result.validate_structured_output( # (2)!
614+
profile = await result.validate_response_output( # (2)!
615615
message,
616616
allow_partial=not last,
617617
)
@@ -627,8 +627,8 @@ async def main():
627627
#> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes the chain the dog and the pyramid'}
628628
```
629629

630-
1. [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] streams the data as [`ModelResponse`][pydantic_ai.messages.ModelResponse] objects, thus iteration can't fail with a `ValidationError`.
631-
2. [`validate_structured_output`][pydantic_ai.result.StreamedRunResult.validate_structured_output] validates the data, `allow_partial=True` enables pydantic's [`experimental_allow_partial` flag on `TypeAdapter`][pydantic.type_adapter.TypeAdapter.validate_json].
630+
1. [`stream_responses`][pydantic_ai.result.StreamedRunResult.stream_responses] streams the data as [`ModelResponse`][pydantic_ai.messages.ModelResponse] objects, thus iteration can't fail with a `ValidationError`.
631+
2. [`validate_response_output`][pydantic_ai.result.StreamedRunResult.validate_response_output] validates the data, `allow_partial=True` enables pydantic's [`experimental_allow_partial` flag on `TypeAdapter`][pydantic.type_adapter.TypeAdapter.validate_json].
632632

633633
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
634634

examples/pydantic_ai_examples/chat_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ async def stream_messages():
127127
messages = await database.get_messages()
128128
# run the agent with the user prompt and the chat history
129129
async with agent.run_stream(prompt, message_history=messages) as result:
130-
async for text in result.stream(debounce_by=0.01):
130+
async for text in result.stream_output(debounce_by=0.01):
131131
# text here is a `str` and the frontend wants
132132
# JSON encoded ModelResponse, so we create one
133133
m = ModelResponse(parts=[TextPart(text)], timestamp=result.timestamp())

examples/pydantic_ai_examples/stream_markdown.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def main():
4242
console.log(f'Using model: {model}')
4343
with Live('', console=console, vertical_overflow='visible') as live:
4444
async with agent.run_stream(prompt, model=model) as result:
45-
async for message in result.stream():
45+
async for message in result.stream_output():
4646
live.update(Markdown(message))
4747
console.log(result.usage())
4848
else:

examples/pydantic_ai_examples/stream_whales.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def main():
5151
) as result:
5252
console.print('Response:', style='green')
5353

54-
async for whales in result.stream(debounce_by=0.01):
54+
async for whales in result.stream_output(debounce_by=0.01):
5555
table = Table(
5656
title='Species of Whale',
5757
caption='Streaming Structured responses from GPT-4',

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -678,22 +678,28 @@ def _run_span_end_attributes(
678678
self, state: _agent_graph.GraphAgentState, usage: _usage.RunUsage, settings: InstrumentationSettings
679679
):
680680
if settings.version == 1:
681-
attr_name = 'all_messages_events'
682-
value = [
683-
InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(state.message_history)
684-
]
681+
attrs = {
682+
'all_messages_events': json.dumps(
683+
[
684+
InstrumentedModel.event_to_dict(e)
685+
for e in settings.messages_to_otel_events(state.message_history)
686+
]
687+
)
688+
}
685689
else:
686-
attr_name = 'pydantic_ai.all_messages'
687-
value = settings.messages_to_otel_messages(state.message_history)
690+
attrs = {
691+
'pydantic_ai.all_messages': json.dumps(settings.messages_to_otel_messages(state.message_history)),
692+
**settings.system_instructions_attributes(self._instructions),
693+
}
688694

689695
return {
690696
**usage.opentelemetry_attributes(),
691-
attr_name: json.dumps(value),
697+
**attrs,
692698
'logfire.json_schema': json.dumps(
693699
{
694700
'type': 'object',
695701
'properties': {
696-
attr_name: {'type': 'array'},
702+
**{attr: {'type': 'array'} for attr in attrs.keys()},
697703
'final_result': {'type': 'object'},
698704
},
699705
}

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
11
from __future__ import annotations
22

33
import warnings
4-
from collections.abc import Sequence
4+
from collections.abc import AsyncIterator, Sequence
5+
from contextlib import AbstractAsyncContextManager
56
from dataclasses import replace
67
from typing import Any, Callable
78

89
from pydantic.errors import PydanticUserError
9-
from temporalio.client import ClientConfig, Plugin as ClientPlugin
10+
from temporalio.client import ClientConfig, Plugin as ClientPlugin, WorkflowHistory
1011
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
11-
from temporalio.converter import DefaultPayloadConverter
12-
from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig
12+
from temporalio.converter import DataConverter, DefaultPayloadConverter
13+
from temporalio.service import ConnectConfig, ServiceClient
14+
from temporalio.worker import (
15+
Plugin as WorkerPlugin,
16+
Replayer,
17+
ReplayerConfig,
18+
Worker,
19+
WorkerConfig,
20+
WorkflowReplayResult,
21+
)
1322
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
1423

1524
from ...exceptions import UserError
@@ -31,17 +40,15 @@
3140
class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
3241
"""Temporal client and worker plugin for Pydantic AI."""
3342

34-
def configure_client(self, config: ClientConfig) -> ClientConfig:
35-
if (data_converter := config.get('data_converter')) and data_converter.payload_converter_class not in (
36-
DefaultPayloadConverter,
37-
PydanticPayloadConverter,
38-
):
39-
warnings.warn( # pragma: no cover
40-
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
41-
)
43+
def init_client_plugin(self, next: ClientPlugin) -> None:
44+
self.next_client_plugin = next
4245

43-
config['data_converter'] = pydantic_data_converter
44-
return super().configure_client(config)
46+
def init_worker_plugin(self, next: WorkerPlugin) -> None:
47+
self.next_worker_plugin = next
48+
49+
def configure_client(self, config: ClientConfig) -> ClientConfig:
50+
config['data_converter'] = self._get_new_data_converter(config.get('data_converter'))
51+
return self.next_client_plugin.configure_client(config)
4552

4653
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
4754
runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType]
@@ -67,7 +74,35 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
6774
PydanticUserError,
6875
]
6976

70-
return super().configure_worker(config)
77+
return self.next_worker_plugin.configure_worker(config)
78+
79+
async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
80+
return await self.next_client_plugin.connect_service_client(config)
81+
82+
async def run_worker(self, worker: Worker) -> None:
83+
await self.next_worker_plugin.run_worker(worker)
84+
85+
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
86+
config['data_converter'] = self._get_new_data_converter(config.get('data_converter')) # pyright: ignore[reportUnknownMemberType]
87+
return self.next_worker_plugin.configure_replayer(config)
88+
89+
def run_replayer(
90+
self,
91+
replayer: Replayer,
92+
histories: AsyncIterator[WorkflowHistory],
93+
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
94+
return self.next_worker_plugin.run_replayer(replayer, histories)
95+
96+
def _get_new_data_converter(self, converter: DataConverter | None) -> DataConverter:
97+
if converter and converter.payload_converter_class not in (
98+
DefaultPayloadConverter,
99+
PydanticPayloadConverter,
100+
):
101+
warnings.warn( # pragma: no cover
102+
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
103+
)
104+
105+
return pydantic_data_converter
71106

72107

73108
class AgentPlugin(WorkerPlugin):
@@ -76,8 +111,24 @@ class AgentPlugin(WorkerPlugin):
76111
def __init__(self, agent: TemporalAgent[Any, Any]):
77112
self.agent = agent
78113

114+
def init_worker_plugin(self, next: WorkerPlugin) -> None:
115+
self.next_worker_plugin = next
116+
79117
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
80118
activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType]
81119
# Activities are checked for name conflicts by Temporal.
82120
config['activities'] = [*activities, *self.agent.temporal_activities]
83-
return super().configure_worker(config)
121+
return self.next_worker_plugin.configure_worker(config)
122+
123+
async def run_worker(self, worker: Worker) -> None:
124+
await self.next_worker_plugin.run_worker(worker)
125+
126+
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
127+
return self.next_worker_plugin.configure_replayer(config)
128+
129+
def run_replayer(
130+
self,
131+
replayer: Replayer,
132+
histories: AsyncIterator[WorkflowHistory],
133+
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
134+
return self.next_worker_plugin.run_replayer(replayer, histories)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_logfire.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@ def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire
2525
self.setup_logfire = setup_logfire
2626
self.metrics = metrics
2727

28+
def init_client_plugin(self, next: ClientPlugin) -> None:
29+
self.next_client_plugin = next
30+
2831
def configure_client(self, config: ClientConfig) -> ClientConfig:
2932
interceptors = config.get('interceptors', [])
3033
config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))]
31-
return super().configure_client(config)
34+
return self.next_client_plugin.configure_client(config)
3235

3336
async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
3437
logfire = self.setup_logfire()
@@ -45,4 +48,4 @@ async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
4548
telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=headers))
4649
)
4750

48-
return await super().connect_service_client(config)
51+
return await self.next_client_plugin.connect_service_client(config)

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,27 +236,36 @@ def handle_messages(self, input_messages: list[ModelMessage], response: ModelRes
236236
if response.provider_details and 'finish_reason' in response.provider_details:
237237
output_message['finish_reason'] = response.provider_details['finish_reason']
238238
instructions = InstrumentedModel._get_instructions(input_messages) # pyright: ignore [reportPrivateUsage]
239+
system_instructions_attributes = self.system_instructions_attributes(instructions)
239240
attributes = {
240241
'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
241242
'gen_ai.output.messages': json.dumps([output_message]),
243+
**system_instructions_attributes,
242244
'logfire.json_schema': json.dumps(
243245
{
244246
'type': 'object',
245247
'properties': {
246248
'gen_ai.input.messages': {'type': 'array'},
247249
'gen_ai.output.messages': {'type': 'array'},
248-
**({'gen_ai.system_instructions': {'type': 'array'}} if instructions else {}),
250+
**(
251+
{'gen_ai.system_instructions': {'type': 'array'}}
252+
if system_instructions_attributes
253+
else {}
254+
),
249255
'model_request_parameters': {'type': 'object'},
250256
},
251257
}
252258
),
253259
}
254-
if instructions is not None:
255-
attributes['gen_ai.system_instructions'] = json.dumps(
256-
[_otel_messages.TextPart(type='text', content=instructions)]
257-
)
258260
span.set_attributes(attributes)
259261

262+
def system_instructions_attributes(self, instructions: str | None) -> dict[str, str]:
263+
if instructions and self.include_content:
264+
return {
265+
'gen_ai.system_instructions': json.dumps([_otel_messages.TextPart(type='text', content=instructions)]),
266+
}
267+
return {}
268+
260269
def _emit_events(self, span: Span, events: list[Event]) -> None:
261270
if self.event_mode == 'logs':
262271
for event in events:

0 commit comments

Comments
 (0)