Skip to content

Commit 27dbef6

Browse files
authored
Unify model responses (#232)
1 parent 7d27c42 commit 27dbef6

32 files changed

+830
-739
lines changed

docs/agents.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -353,16 +353,17 @@ except UnexpectedModelBehavior as e:
353353
timestamp=datetime.datetime(...),
354354
role='user',
355355
),
356-
ModelStructuredResponse(
357-
calls=[
358-
ToolCall(
356+
ModelResponse(
357+
parts=[
358+
ToolCallPart(
359359
tool_name='calc_volume',
360360
args=ArgsDict(args_dict={'size': 6}),
361361
tool_call_id=None,
362+
kind='tool-call',
362363
)
363364
],
365+
role='model-response',
364366
timestamp=datetime.datetime(...),
365-
role='model-structured-response',
366367
),
367368
RetryPrompt(
368369
content='Please try again.',
@@ -371,16 +372,17 @@ except UnexpectedModelBehavior as e:
371372
timestamp=datetime.datetime(...),
372373
role='retry-prompt',
373374
),
374-
ModelStructuredResponse(
375-
calls=[
376-
ToolCall(
375+
ModelResponse(
376+
parts=[
377+
ToolCallPart(
377378
tool_name='calc_volume',
378379
args=ArgsDict(args_dict={'size': 6}),
379380
tool_call_id=None,
381+
kind='tool-call',
380382
)
381383
],
384+
role='model-response',
382385
timestamp=datetime.datetime(...),
383-
role='model-structured-response',
384386
),
385387
]
386388
"""

docs/api/messages.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88
- UserPrompt
99
- ToolReturn
1010
- RetryPrompt
11-
- ModelAnyResponse
12-
- ModelTextResponse
13-
- ModelStructuredResponse
11+
- ModelResponse
1412
- ToolCall
1513
- ArgsJson
1614
- ArgsObject

docs/message-history.md

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,15 @@ print(result.all_messages())
4646
timestamp=datetime.datetime(...),
4747
role='user',
4848
),
49-
ModelTextResponse(
50-
content='Did you hear about the toothpaste scandal? They called it Colgate.',
49+
ModelResponse(
50+
parts=[
51+
TextPart(
52+
content='Did you hear about the toothpaste scandal? They called it Colgate.',
53+
kind='text',
54+
)
55+
],
56+
role='model-response',
5157
timestamp=datetime.datetime(...),
52-
role='model-text-response',
5358
),
5459
]
5560
"""
@@ -63,10 +68,15 @@ print(result.new_messages())
6368
timestamp=datetime.datetime(...),
6469
role='user',
6570
),
66-
ModelTextResponse(
67-
content='Did you hear about the toothpaste scandal? They called it Colgate.',
71+
ModelResponse(
72+
parts=[
73+
TextPart(
74+
content='Did you hear about the toothpaste scandal? They called it Colgate.',
75+
kind='text',
76+
)
77+
],
78+
role='model-response',
6879
timestamp=datetime.datetime(...),
69-
role='model-text-response',
7080
),
7181
]
7282
"""
@@ -113,10 +123,15 @@ async def main():
113123
timestamp=datetime.datetime(...),
114124
role='user',
115125
),
116-
ModelTextResponse(
117-
content='Did you hear about the toothpaste scandal? They called it Colgate.',
126+
ModelResponse(
127+
parts=[
128+
TextPart(
129+
content='Did you hear about the toothpaste scandal? They called it Colgate.',
130+
kind='text',
131+
)
132+
],
133+
role='model-response',
118134
timestamp=datetime.datetime(...),
119-
role='model-text-response',
120135
),
121136
]
122137
"""
@@ -164,20 +179,30 @@ print(result2.all_messages())
164179
timestamp=datetime.datetime(...),
165180
role='user',
166181
),
167-
ModelTextResponse(
168-
content='Did you hear about the toothpaste scandal? They called it Colgate.',
182+
ModelResponse(
183+
parts=[
184+
TextPart(
185+
content='Did you hear about the toothpaste scandal? They called it Colgate.',
186+
kind='text',
187+
)
188+
],
189+
role='model-response',
169190
timestamp=datetime.datetime(...),
170-
role='model-text-response',
171191
),
172192
UserPrompt(
173193
content='Explain?',
174194
timestamp=datetime.datetime(...),
175195
role='user',
176196
),
177-
ModelTextResponse(
178-
content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.',
197+
ModelResponse(
198+
parts=[
199+
TextPart(
200+
content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.',
201+
kind='text',
202+
)
203+
],
204+
role='model-response',
179205
timestamp=datetime.datetime(...),
180-
role='model-text-response',
181206
),
182207
]
183208
"""
@@ -214,20 +239,30 @@ print(result2.all_messages())
214239
timestamp=datetime.datetime(...),
215240
role='user',
216241
),
217-
ModelTextResponse(
218-
content='Did you hear about the toothpaste scandal? They called it Colgate.',
242+
ModelResponse(
243+
parts=[
244+
TextPart(
245+
content='Did you hear about the toothpaste scandal? They called it Colgate.',
246+
kind='text',
247+
)
248+
],
249+
role='model-response',
219250
timestamp=datetime.datetime(...),
220-
role='model-text-response',
221251
),
222252
UserPrompt(
223253
content='Explain?',
224254
timestamp=datetime.datetime(...),
225255
role='user',
226256
),
227-
ModelTextResponse(
228-
content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.',
257+
ModelResponse(
258+
parts=[
259+
TextPart(
260+
content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.',
261+
kind='text',
262+
)
263+
],
264+
role='model-response',
229265
timestamp=datetime.datetime(...),
230-
role='model-text-response',
231266
),
232267
]
233268
"""

docs/results.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ async def main():
302302
#> {'name': 'Ben', 'dob': date(1990, 1, 28), 'bio': 'Likes the chain the dog and the pyramid'}
303303
```
304304

305-
1. [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] streams the data as [`ModelStructuredResponse`][pydantic_ai.messages.ModelStructuredResponse] objects, thus iteration can't fail with a `ValidationError`.
305+
1. [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] streams the data as [`ModelResponse`][pydantic_ai.messages.ModelResponse] objects, thus iteration can't fail with a `ValidationError`.
306306
2. [`validate_structured_result`][pydantic_ai.result.StreamedRunResult.validate_structured_result] validates the data, `allow_partial=True` enables pydantic's [`experimental_allow_partial` flag on `TypeAdapter`][pydantic.type_adapter.TypeAdapter.validate_json].
307307

308308
_(This example is complete, it can be run "as is")_

docs/testing-evals.md

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,13 @@ from dirty_equals import IsNow
9898
from pydantic_ai import models
9999
from pydantic_ai.models.test import TestModel
100100
from pydantic_ai.messages import (
101-
SystemPrompt,
102-
UserPrompt,
103-
ModelStructuredResponse,
104-
ToolCall,
105101
ArgsDict,
102+
ModelResponse,
103+
SystemPrompt,
104+
TextPart,
105+
ToolCallPart,
106106
ToolReturn,
107-
ModelTextResponse,
107+
UserPrompt,
108108
)
109109

110110
from fake_database import DatabaseConn
@@ -134,9 +134,9 @@ async def test_forecast():
134134
timestamp=IsNow(tz=timezone.utc), # (7)!
135135
role='user',
136136
),
137-
ModelStructuredResponse(
138-
calls=[
139-
ToolCall(
137+
ModelResponse(
138+
parts=[
139+
ToolCallPart(
140140
tool_name='weather_forecast',
141141
args=ArgsDict(
142142
args_dict={
@@ -148,7 +148,7 @@ async def test_forecast():
148148
)
149149
],
150150
timestamp=IsNow(tz=timezone.utc),
151-
role='model-structured-response',
151+
role='model-response',
152152
),
153153
ToolReturn(
154154
tool_name='weather_forecast',
@@ -157,10 +157,14 @@ async def test_forecast():
157157
timestamp=IsNow(tz=timezone.utc),
158158
role='tool-return',
159159
),
160-
ModelTextResponse(
161-
content='{"weather_forecast":"Sunny with a chance of rain"}',
160+
ModelResponse(
161+
parts=[
162+
TextPart(
163+
content='{"weather_forecast":"Sunny with a chance of rain"}',
164+
)
165+
],
162166
timestamp=IsNow(tz=timezone.utc),
163-
role='model-text-response',
167+
role='model-response',
164168
),
165169
]
166170
```
@@ -190,10 +194,8 @@ import pytest
190194
from pydantic_ai import models
191195
from pydantic_ai.messages import (
192196
Message,
193-
ModelAnyResponse,
194-
ModelStructuredResponse,
195-
ModelTextResponse,
196-
ToolCall,
197+
ModelResponse,
198+
ToolCallPart,
197199
)
198200
from pydantic_ai.models.function import AgentInfo, FunctionModel
199201

@@ -206,21 +208,19 @@ models.ALLOW_MODEL_REQUESTS = False
206208

207209
def call_weather_forecast( # (1)!
208210
messages: list[Message], info: AgentInfo
209-
) -> ModelAnyResponse:
211+
) -> ModelResponse:
210212
if len(messages) == 2:
211213
# first call, call the weather forecast tool
212214
user_prompt = messages[1]
213215
m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content)
214216
assert m is not None
215217
args = {'location': 'London', 'forecast_date': m.group()} # (2)!
216-
return ModelStructuredResponse(
217-
calls=[ToolCall.from_dict('weather_forecast', args)]
218-
)
218+
return ModelResponse(parts=[ToolCallPart.from_dict('weather_forecast', args)])
219219
else:
220220
# second call, return the forecast
221221
msg = messages[-1]
222222
assert msg.role == 'tool-return'
223-
return ModelTextResponse(f'The forecast is: {msg.content}')
223+
return ModelResponse.from_text(f'The forecast is: {msg.content}')
224224

225225

226226
async def test_forecast_future():

docs/tools.md

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,17 @@ print(dice_result.all_messages())
7777
timestamp=datetime.datetime(...),
7878
role='user',
7979
),
80-
ModelStructuredResponse(
81-
calls=[
82-
ToolCall(
83-
tool_name='roll_die', args=ArgsDict(args_dict={}), tool_call_id=None
80+
ModelResponse(
81+
parts=[
82+
ToolCallPart(
83+
tool_name='roll_die',
84+
args=ArgsDict(args_dict={}),
85+
tool_call_id=None,
86+
kind='tool-call',
8487
)
8588
],
89+
role='model-response',
8690
timestamp=datetime.datetime(...),
87-
role='model-structured-response',
8891
),
8992
ToolReturn(
9093
tool_name='roll_die',
@@ -93,16 +96,17 @@ print(dice_result.all_messages())
9396
timestamp=datetime.datetime(...),
9497
role='tool-return',
9598
),
96-
ModelStructuredResponse(
97-
calls=[
98-
ToolCall(
99+
ModelResponse(
100+
parts=[
101+
ToolCallPart(
99102
tool_name='get_player_name',
100103
args=ArgsDict(args_dict={}),
101104
tool_call_id=None,
105+
kind='tool-call',
102106
)
103107
],
108+
role='model-response',
104109
timestamp=datetime.datetime(...),
105-
role='model-structured-response',
106110
),
107111
ToolReturn(
108112
tool_name='get_player_name',
@@ -111,10 +115,15 @@ print(dice_result.all_messages())
111115
timestamp=datetime.datetime(...),
112116
role='tool-return',
113117
),
114-
ModelTextResponse(
115-
content="Congratulations Anne, you guessed correctly! You're a winner!",
118+
ModelResponse(
119+
parts=[
120+
TextPart(
121+
content="Congratulations Anne, you guessed correctly! You're a winner!",
122+
kind='text',
123+
)
124+
],
125+
role='model-response',
116126
timestamp=datetime.datetime(...),
117-
role='model-text-response',
118127
),
119128
]
120129
"""
@@ -151,7 +160,7 @@ sequenceDiagram
151160
activate LLM
152161
Note over LLM: LLM constructs final response
153162
154-
LLM ->> Agent: ModelTextResponse<br>"Congratulations Anne, ..."
163+
LLM ->> Agent: ModelResponse<br>"Congratulations Anne, ..."
155164
deactivate LLM
156165
Note over Agent: Game session complete
157166
```
@@ -215,7 +224,7 @@ To demonstrate a tool's schema, here we use [`FunctionModel`][pydantic_ai.models
215224

216225
```python {title="tool_schema.py"}
217226
from pydantic_ai import Agent
218-
from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
227+
from pydantic_ai.messages import Message, ModelResponse
219228
from pydantic_ai.models.function import AgentInfo, FunctionModel
220229

221230
agent = Agent()
@@ -233,7 +242,7 @@ def foobar(a: int, b: str, c: dict[str, list[float]]) -> str:
233242
return f'{a} {b} {c}'
234243

235244

236-
def print_schema(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
245+
def print_schema(messages: list[Message], info: AgentInfo) -> ModelResponse:
237246
tool = info.function_tools[0]
238247
print(tool.description)
239248
#> Get me foobar.
@@ -255,7 +264,7 @@ def print_schema(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
255264
'additionalProperties': False,
256265
}
257266
"""
258-
return ModelTextResponse(content='foobar')
267+
return ModelResponse.from_text(content='foobar')
259268

260269

261270
agent.run_sync('hello', model=FunctionModel(print_schema))

0 commit comments

Comments
 (0)