Skip to content

Commit b08874b

Browse files
Change: Internal Discriminator Modification (#247)
1 parent 39cec5f commit b08874b

File tree

14 files changed

+197
-108
lines changed

14 files changed

+197
-108
lines changed

docs/agents.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,37 +352,41 @@ except UnexpectedModelBehavior as e:
352352
content='Please get me the volume of a box with size 6.',
353353
timestamp=datetime.datetime(...),
354354
role='user',
355+
message_kind='user-prompt',
355356
),
356357
ModelResponse(
357358
parts=[
358359
ToolCallPart(
359360
tool_name='calc_volume',
360361
args=ArgsDict(args_dict={'size': 6}),
361362
tool_call_id=None,
362-
kind='tool-call',
363+
part_kind='tool-call',
363364
)
364365
],
365-
role='model-response',
366366
timestamp=datetime.datetime(...),
367+
role='model',
368+
message_kind='model-response',
367369
),
368370
RetryPrompt(
369371
content='Please try again.',
370372
tool_name='calc_volume',
371373
tool_call_id=None,
372374
timestamp=datetime.datetime(...),
373-
role='retry-prompt',
375+
role='user',
376+
message_kind='retry-prompt',
374377
),
375378
ModelResponse(
376379
parts=[
377380
ToolCallPart(
378381
tool_name='calc_volume',
379382
args=ArgsDict(args_dict={'size': 6}),
380383
tool_call_id=None,
381-
kind='tool-call',
384+
part_kind='tool-call',
382385
)
383386
],
384-
role='model-response',
385387
timestamp=datetime.datetime(...),
388+
role='model',
389+
message_kind='model-response',
386390
),
387391
]
388392
"""

docs/message-history.md

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,25 @@ print(result.data)
4040
print(result.all_messages())
4141
"""
4242
[
43-
SystemPrompt(content='Be a helpful assistant.', role='system'),
43+
SystemPrompt(
44+
content='Be a helpful assistant.', role='user', message_kind='system-prompt'
45+
),
4446
UserPrompt(
4547
content='Tell me a joke.',
4648
timestamp=datetime.datetime(...),
4749
role='user',
50+
message_kind='user-prompt',
4851
),
4952
ModelResponse(
5053
parts=[
5154
TextPart(
5255
content='Did you hear about the toothpaste scandal? They called it Colgate.',
53-
kind='text',
56+
part_kind='text',
5457
)
5558
],
56-
role='model-response',
5759
timestamp=datetime.datetime(...),
60+
role='model',
61+
message_kind='model-response',
5862
),
5963
]
6064
"""
@@ -67,16 +71,18 @@ print(result.new_messages())
6771
content='Tell me a joke.',
6872
timestamp=datetime.datetime(...),
6973
role='user',
74+
message_kind='user-prompt',
7075
),
7176
ModelResponse(
7277
parts=[
7378
TextPart(
7479
content='Did you hear about the toothpaste scandal? They called it Colgate.',
75-
kind='text',
80+
part_kind='text',
7681
)
7782
],
78-
role='model-response',
7983
timestamp=datetime.datetime(...),
84+
role='model',
85+
message_kind='model-response',
8086
),
8187
]
8288
"""
@@ -97,11 +103,16 @@ async def main():
97103
print(result.all_messages())
98104
"""
99105
[
100-
SystemPrompt(content='Be a helpful assistant.', role='system'),
106+
SystemPrompt(
107+
content='Be a helpful assistant.',
108+
role='user',
109+
message_kind='system-prompt',
110+
),
101111
UserPrompt(
102112
content='Tell me a joke.',
103113
timestamp=datetime.datetime(...),
104114
role='user',
115+
message_kind='user-prompt',
105116
),
106117
]
107118
"""
@@ -117,21 +128,27 @@ async def main():
117128
print(result.all_messages())
118129
"""
119130
[
120-
SystemPrompt(content='Be a helpful assistant.', role='system'),
131+
SystemPrompt(
132+
content='Be a helpful assistant.',
133+
role='user',
134+
message_kind='system-prompt',
135+
),
121136
UserPrompt(
122137
content='Tell me a joke.',
123138
timestamp=datetime.datetime(...),
124139
role='user',
140+
message_kind='user-prompt',
125141
),
126142
ModelResponse(
127143
parts=[
128144
TextPart(
129145
content='Did you hear about the toothpaste scandal? They called it Colgate.',
130-
kind='text',
146+
part_kind='text',
131147
)
132148
],
133-
role='model-response',
134149
timestamp=datetime.datetime(...),
150+
role='model',
151+
message_kind='model-response',
135152
),
136153
]
137154
"""
@@ -173,36 +190,42 @@ print(result2.data)
173190
print(result2.all_messages())
174191
"""
175192
[
176-
SystemPrompt(content='Be a helpful assistant.', role='system'),
193+
SystemPrompt(
194+
content='Be a helpful assistant.', role='user', message_kind='system-prompt'
195+
),
177196
UserPrompt(
178197
content='Tell me a joke.',
179198
timestamp=datetime.datetime(...),
180199
role='user',
200+
message_kind='user-prompt',
181201
),
182202
ModelResponse(
183203
parts=[
184204
TextPart(
185205
content='Did you hear about the toothpaste scandal? They called it Colgate.',
186-
kind='text',
206+
part_kind='text',
187207
)
188208
],
189-
role='model-response',
190209
timestamp=datetime.datetime(...),
210+
role='model',
211+
message_kind='model-response',
191212
),
192213
UserPrompt(
193214
content='Explain?',
194215
timestamp=datetime.datetime(...),
195216
role='user',
217+
message_kind='user-prompt',
196218
),
197219
ModelResponse(
198220
parts=[
199221
TextPart(
200222
content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.',
201-
kind='text',
223+
part_kind='text',
202224
)
203225
],
204-
role='model-response',
205226
timestamp=datetime.datetime(...),
227+
role='model',
228+
message_kind='model-response',
206229
),
207230
]
208231
"""
@@ -233,36 +256,42 @@ print(result2.data)
233256
print(result2.all_messages())
234257
"""
235258
[
236-
SystemPrompt(content='Be a helpful assistant.', role='system'),
259+
SystemPrompt(
260+
content='Be a helpful assistant.', role='user', message_kind='system-prompt'
261+
),
237262
UserPrompt(
238263
content='Tell me a joke.',
239264
timestamp=datetime.datetime(...),
240265
role='user',
266+
message_kind='user-prompt',
241267
),
242268
ModelResponse(
243269
parts=[
244270
TextPart(
245271
content='Did you hear about the toothpaste scandal? They called it Colgate.',
246-
kind='text',
272+
part_kind='text',
247273
)
248274
],
249-
role='model-response',
250275
timestamp=datetime.datetime(...),
276+
role='model',
277+
message_kind='model-response',
251278
),
252279
UserPrompt(
253280
content='Explain?',
254281
timestamp=datetime.datetime(...),
255282
role='user',
283+
message_kind='user-prompt',
256284
),
257285
ModelResponse(
258286
parts=[
259287
TextPart(
260288
content='This is an excellent joke invent by Samuel Colvin, it needs no explanation.',
261-
kind='text',
289+
part_kind='text',
262290
)
263291
],
264-
role='model-response',
265292
timestamp=datetime.datetime(...),
293+
role='model',
294+
message_kind='model-response',
266295
),
267296
]
268297
"""

docs/testing-evals.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,14 @@ async def test_forecast():
127127
assert weather_agent.last_run_messages == [ # (6)!
128128
SystemPrompt(
129129
content='Providing a weather forecast at the locations the user provides.',
130-
role='system',
130+
role='user',
131+
message_kind='system-prompt',
131132
),
132133
UserPrompt(
133134
content='What will the weather be like in London on 2024-11-28?',
134135
timestamp=IsNow(tz=timezone.utc), # (7)!
135136
role='user',
137+
message_kind='user-prompt',
136138
),
137139
ModelResponse(
138140
parts=[
@@ -148,14 +150,16 @@ async def test_forecast():
148150
)
149151
],
150152
timestamp=IsNow(tz=timezone.utc),
151-
role='model-response',
153+
role='model',
154+
message_kind='model-response',
152155
),
153156
ToolReturn(
154157
tool_name='weather_forecast',
155158
content='Sunny with a chance of rain',
156159
tool_call_id=None,
157160
timestamp=IsNow(tz=timezone.utc),
158-
role='tool-return',
161+
role='user',
162+
message_kind='tool-return',
159163
),
160164
ModelResponse(
161165
parts=[
@@ -164,7 +168,8 @@ async def test_forecast():
164168
)
165169
],
166170
timestamp=IsNow(tz=timezone.utc),
167-
role='model-response',
171+
role='model',
172+
message_kind='model-response',
168173
),
169174
]
170175
```
@@ -219,7 +224,7 @@ def call_weather_forecast( # (1)!
219224
else:
220225
# second call, return the forecast
221226
msg = messages[-1]
222-
assert msg.role == 'tool-return'
227+
assert msg.message_kind == 'tool-return'
223228
return ModelResponse.from_text(f'The forecast is: {msg.content}')
224229

225230

docs/tools.md

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,60 +70,67 @@ print(dice_result.all_messages())
7070
[
7171
SystemPrompt(
7272
content="You're a dice game, you should roll the die and see if the number you get back matches the user's guess. If so, tell them they're a winner. Use the player's name in the response.",
73-
role='system',
73+
role='user',
74+
message_kind='system-prompt',
7475
),
7576
UserPrompt(
7677
content='My guess is 4',
7778
timestamp=datetime.datetime(...),
7879
role='user',
80+
message_kind='user-prompt',
7981
),
8082
ModelResponse(
8183
parts=[
8284
ToolCallPart(
8385
tool_name='roll_die',
8486
args=ArgsDict(args_dict={}),
8587
tool_call_id=None,
86-
kind='tool-call',
88+
part_kind='tool-call',
8789
)
8890
],
89-
role='model-response',
9091
timestamp=datetime.datetime(...),
92+
role='model',
93+
message_kind='model-response',
9194
),
9295
ToolReturn(
9396
tool_name='roll_die',
9497
content='4',
9598
tool_call_id=None,
9699
timestamp=datetime.datetime(...),
97-
role='tool-return',
100+
role='user',
101+
message_kind='tool-return',
98102
),
99103
ModelResponse(
100104
parts=[
101105
ToolCallPart(
102106
tool_name='get_player_name',
103107
args=ArgsDict(args_dict={}),
104108
tool_call_id=None,
105-
kind='tool-call',
109+
part_kind='tool-call',
106110
)
107111
],
108-
role='model-response',
109112
timestamp=datetime.datetime(...),
113+
role='model',
114+
message_kind='model-response',
110115
),
111116
ToolReturn(
112117
tool_name='get_player_name',
113118
content='Anne',
114119
tool_call_id=None,
115120
timestamp=datetime.datetime(...),
116-
role='tool-return',
121+
role='user',
122+
message_kind='tool-return',
117123
),
118124
ModelResponse(
119125
parts=[
120126
TextPart(
121127
content="Congratulations Anne, you guessed correctly! You're a winner!",
122-
kind='text',
128+
part_kind='text',
123129
)
124130
],
125-
role='model-response',
126131
timestamp=datetime.datetime(...),
132+
role='model',
133+
message_kind='model-response',
127134
),
128135
]
129136
"""

pydantic_ai_examples/chat_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def stream_messages():
9898

9999

100100
MessageTypeAdapter: TypeAdapter[Message] = TypeAdapter(
101-
Annotated[Message, Field(discriminator='role')]
101+
Annotated[Message, Field(discriminator='message_kind')]
102102
)
103103
P = ParamSpec('P')
104104
R = TypeVar('R')

0 commit comments

Comments
 (0)