Skip to content

Commit 81bf883

Browse files
Removing from_text and from_tool_call utilities that complicate snapshot testing (#744)
1 parent 10b3c91 commit 81bf883

File tree

13 files changed

+53
-51
lines changed

13 files changed

+53
-51
lines changed

docs/api/models/function.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Here's a minimal example:
1111

1212
```py {title="function_model_usage.py" call_name="test_my_agent" noqa="I001"}
1313
from pydantic_ai import Agent
14-
from pydantic_ai.messages import ModelMessage, ModelResponse
14+
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
1515
from pydantic_ai.models.function import FunctionModel, AgentInfo
1616

1717
my_agent = Agent('openai:gpt-4o')
@@ -41,7 +41,7 @@ async def model_function(
4141
function_tools=[], allow_text_result=True, result_tools=[], model_settings=None
4242
)
4343
"""
44-
return ModelResponse.from_text('hello world')
44+
return ModelResponse(parts=[TextPart('hello world')])
4545

4646

4747
async def test_my_agent():

docs/testing-evals.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ from pydantic_ai import models
202202
from pydantic_ai.messages import (
203203
ModelMessage,
204204
ModelResponse,
205+
TextPart,
205206
ToolCallPart,
206207
)
207208
from pydantic_ai.models.function import AgentInfo, FunctionModel
@@ -229,7 +230,7 @@ def call_weather_forecast( # (1)!
229230
# second call, return the forecast
230231
msg = messages[-1].parts[0]
231232
assert msg.part_kind == 'tool-return'
232-
return ModelResponse.from_text(f'The forecast is: {msg.content}')
233+
return ModelResponse(parts=[TextPart(f'The forecast is: {msg.content}')])
233234

234235

235236
async def test_forecast_future():

docs/tools.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ To demonstrate a tool's schema, here we use [`FunctionModel`][pydantic_ai.models
243243

244244
```python {title="tool_schema.py"}
245245
from pydantic_ai import Agent
246-
from pydantic_ai.messages import ModelMessage, ModelResponse
246+
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
247247
from pydantic_ai.models.function import AgentInfo, FunctionModel
248248

249249
agent = Agent()
@@ -283,7 +283,7 @@ def print_schema(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse
283283
'additionalProperties': False,
284284
}
285285
"""
286-
return ModelResponse.from_text(content='foobar')
286+
return ModelResponse(parts=[TextPart('foobar')])
287287

288288

289289
agent.run_sync('hello', model=FunctionModel(print_schema))

examples/pydantic_ai_examples/chat_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ async def stream_messages():
128128
async for text in result.stream(debounce_by=0.01):
129129
# text here is a `str` and the frontend wants
130130
# JSON encoded ModelResponse, so we create one
131-
m = ModelResponse.from_text(content=text, timestamp=result.timestamp())
131+
m = ModelResponse(parts=[TextPart(text)], timestamp=result.timestamp())
132132
yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n'
133133

134134
# add new messages (e.g. the user prompt and the agent response in this case) to the database

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -264,16 +264,6 @@ class ModelResponse:
264264
kind: Literal['response'] = 'response'
265265
"""Message type identifier, this is available on all parts as a discriminator."""
266266

267-
@classmethod
268-
def from_text(cls, content: str, model_name: str | None = None, timestamp: datetime | None = None) -> Self:
269-
"""Create a `ModelResponse` containing a single `TextPart`."""
270-
return cls([TextPart(content=content)], model_name=model_name, timestamp=timestamp or _now_utc())
271-
272-
@classmethod
273-
def from_tool_call(cls, tool_call: ToolCallPart, model_name: str | None = None) -> Self:
274-
"""Create a `ModelResponse` containing a single `ToolCallPart`."""
275-
return cls([tool_call], model_name=model_name)
276-
277267

278268
ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
279269
"""Any message sent to or returned by a model."""

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,13 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings |
191191
if isinstance(part, ToolReturnPart):
192192
output[part.tool_name] = part.content
193193
if output:
194-
return ModelResponse.from_text(pydantic_core.to_json(output).decode(), model_name=self.model_name)
194+
return ModelResponse(
195+
parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
196+
)
195197
else:
196-
return ModelResponse.from_text('success (no tool calls)', model_name=self.model_name)
198+
return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
197199
else:
198-
return ModelResponse.from_text(response_text.value, model_name=self.model_name)
200+
return ModelResponse(parts=[TextPart(response_text.value)], model_name=self.model_name)
199201
else:
200202
assert self.result_tools, 'No result tools provided'
201203
custom_result_args = self.result.right

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,10 @@ async def _stream_text_deltas() -> AsyncIterator[str]:
270270

271271
lf_span.set_attribute('combined_text', combined_validated_text)
272272
await self._marked_completed(
273-
_messages.ModelResponse.from_text(combined_validated_text, self._stream_response.model_name())
273+
_messages.ModelResponse(
274+
parts=[_messages.TextPart(combined_validated_text)],
275+
model_name=self._stream_response.model_name(),
276+
)
274277
)
275278

276279
async def stream_structured(

tests/models/test_cohere.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ModelResponse,
1515
RetryPromptPart,
1616
SystemPromptPart,
17+
TextPart,
1718
ToolCallPart,
1819
ToolReturnPart,
1920
UserPromptPart,
@@ -101,12 +102,12 @@ async def test_request_simple_success(allow_model_requests: None):
101102
assert result.all_messages() == snapshot(
102103
[
103104
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
104-
ModelResponse.from_text(
105-
content='world', model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc)
105+
ModelResponse(
106+
parts=[TextPart('world')], model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc)
106107
),
107108
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
108-
ModelResponse.from_text(
109-
content='world', model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc)
109+
ModelResponse(
110+
parts=[TextPart('world')], model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc)
110111
),
111112
]
112113
)
@@ -294,8 +295,8 @@ async def get_location(loc_name: str) -> str:
294295
)
295296
]
296297
),
297-
ModelResponse.from_text(
298-
content='final response', model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc)
298+
ModelResponse(
299+
parts=[TextPart('final response')], model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc)
299300
),
300301
]
301302
)

tests/models/test_gemini.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def example_usage() -> _GeminiUsageMetaData:
435435

436436

437437
async def test_text_success(get_gemini_client: GetGeminiClient):
438-
response = gemini_response(_content_model_response(ModelResponse.from_text('Hello world')))
438+
response = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello world')])))
439439
gemini_client = get_gemini_client(response)
440440
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
441441
agent = Agent(m)
@@ -525,7 +525,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
525525
)
526526
)
527527
),
528-
gemini_response(_content_model_response(ModelResponse.from_text('final response'))),
528+
gemini_response(_content_model_response(ModelResponse(parts=[TextPart('final response')]))),
529529
]
530530
gemini_client = get_gemini_client(responses)
531531
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
@@ -621,8 +621,8 @@ def handler(_: httpx.Request):
621621

622622
async def test_stream_text(get_gemini_client: GetGeminiClient):
623623
responses = [
624-
gemini_response(_content_model_response(ModelResponse.from_text('Hello '))),
625-
gemini_response(_content_model_response(ModelResponse.from_text('world'))),
624+
gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello ')]))),
625+
gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')]))),
626626
]
627627
json_data = _gemini_streamed_response_ta.dump_json(responses, by_alias=True)
628628
stream = AsyncByteStreamList([json_data[:100], json_data[100:200], json_data[200:]])
@@ -763,7 +763,7 @@ async def bar(y: str) -> str:
763763

764764
async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
765765
responses = [
766-
gemini_response(_content_model_response(ModelResponse.from_text('Hello '))),
766+
gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello ')]))),
767767
gemini_response(
768768
_GeminiContent(
769769
role='model',

tests/models/test_model_function.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
def hello(_messages: list[ModelMessage], _agent_info: AgentInfo) -> ModelResponse:
34-
return ModelResponse.from_text('hello world') # pragma: no cover
34+
return ModelResponse(parts=[TextPart('hello world')]) # pragma: no cover
3535

3636

3737
async def stream_hello(_messages: list[ModelMessage], _agent_info: AgentInfo) -> AsyncIterator[str]:
@@ -55,7 +55,7 @@ async def return_last(messages: list[ModelMessage], _: AgentInfo) -> ModelRespon
5555
response = asdict(last)
5656
response.pop('timestamp', None)
5757
response['message_count'] = len(messages)
58-
return ModelResponse.from_text(' '.join(f'{k}={v!r}' for k, v in response.items()))
58+
return ModelResponse(parts=[TextPart(' '.join(f'{k}={v!r}' for k, v in response.items()))])
5959

6060

6161
def test_simple(set_event_loop: None):
@@ -117,7 +117,7 @@ async def weather_model(messages: list[ModelMessage], info: AgentInfo) -> ModelR
117117
break
118118

119119
assert location_name is not None
120-
return ModelResponse.from_text(f'{last.content} in {location_name}')
120+
return ModelResponse(parts=[TextPart(f'{last.content} in {location_name}')])
121121

122122
raise ValueError(f'Unexpected message: {last}')
123123

@@ -200,7 +200,7 @@ async def call_function_model(messages: list[ModelMessage], _: AgentInfo) -> Mod
200200
]
201201
)
202202
elif isinstance(last, ToolReturnPart):
203-
return ModelResponse.from_text(pydantic_core.to_json(last).decode())
203+
return ModelResponse(parts=[TextPart(pydantic_core.to_json(last).decode())])
204204

205205
raise ValueError(f'Unexpected message: {last}')
206206

@@ -236,7 +236,7 @@ async def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRespo
236236
tool_name = info.function_tools[0].name
237237
return ModelResponse(parts=[ToolCallPart.from_raw_args(tool_name, '{}')])
238238
else:
239-
return ModelResponse.from_text('final response')
239+
return ModelResponse(parts=[TextPart('final response')])
240240

241241

242242
def test_deps_none(set_event_loop: None):
@@ -318,8 +318,12 @@ def spam() -> str:
318318

319319
def test_register_all(set_event_loop: None):
320320
async def f(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
321-
return ModelResponse.from_text(
322-
f'messages={len(messages)} allow_text_result={info.allow_text_result} tools={len(info.function_tools)}'
321+
return ModelResponse(
322+
parts=[
323+
TextPart(
324+
f'messages={len(messages)} allow_text_result={info.allow_text_result} tools={len(info.function_tools)}'
325+
)
326+
]
323327
)
324328

325329
result = agent_all.run_sync('Hello', model=FunctionModel(f))
@@ -373,7 +377,7 @@ async def try_again(msgs_: list[ModelMessage], _agent_info: AgentInfo) -> ModelR
373377
nonlocal call_count
374378
call_count += 1
375379

376-
return ModelResponse.from_text(str(call_count))
380+
return ModelResponse(parts=[TextPart(str(call_count))])
377381

378382
agent = Agent(FunctionModel(try_again))
379383

0 commit comments

Comments
 (0)