6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_cli.py
@@ -228,15 +228,15 @@ async def run_chat(
     prog_name: str,
     config_dir: Path | None = None,
     deps: AgentDepsT = None,
-    message_history: list[ModelMessage] | None = None,
+    message_history: Sequence[ModelMessage] | None = None,
 ) -> int:
     prompt_history_path = (config_dir or PYDANTIC_AI_HOME) / PROMPT_HISTORY_FILENAME
     prompt_history_path.parent.mkdir(parents=True, exist_ok=True)
     prompt_history_path.touch(exist_ok=True)
     session: PromptSession[Any] = PromptSession(history=FileHistory(str(prompt_history_path)))

     multiline = False
-    messages: list[ModelMessage] = message_history[:] if message_history else []
+    messages: list[ModelMessage] = list(message_history) if message_history else []

     while True:
         try:
@@ -272,7 +272,7 @@ async def ask_agent(
     console: Console,
     code_theme: str,
     deps: AgentDepsT = None,
-    messages: list[ModelMessage] | None = None,
+    messages: Sequence[ModelMessage] | None = None,
 ) -> list[ModelMessage]:
     status = Status('[dim]Working on it…[/dim]', console=console)
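The body change above is what makes the widened annotation safe: slicing with [:] is only guaranteed to return a list when the input is a list, whereas list(...) copies any sequence (list, tuple, deque, ...) into a fresh list. A minimal sketch of the pattern, using a stand-in Message type rather than the real ModelMessage:

from collections.abc import Sequence

Message = str  # stand-in for ModelMessage, just for illustration


def start_chat(message_history: Sequence[Message] | None = None) -> list[Message]:
    # list(...) accepts any sequence and always returns an independent list,
    # so appending below can never mutate the caller's object.
    messages: list[Message] = list(message_history) if message_history else []
    messages.append('new user message')
    return messages


start_chat(['a', 'b'])  # a list still works
start_chat(('a', 'b'))  # a tuple is now accepted too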
12 changes: 6 additions & 6 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -415,7 +415,7 @@ def iter(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -432,7 +432,7 @@ def iter(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT],
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -449,7 +449,7 @@ async def iter(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT] | None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -566,7 +566,7 @@ async def main():
         # Build the initial state
         usage = usage or _usage.RunUsage()
         state = _agent_graph.GraphAgentState(
-            message_history=message_history[:] if message_history else [],
+            message_history=list(message_history) if message_history else [],
             usage=usage,
             retries=0,
             run_step=0,
@@ -684,13 +684,13 @@ def _run_span_end_attributes(
             'all_messages_events': json.dumps(
                 [
                     InstrumentedModel.event_to_dict(e)
-                    for e in settings.messages_to_otel_events(state.message_history)
+                    for e in settings.messages_to_otel_events(list(state.message_history))
                 ]
             )
         }
     else:
         attrs = {
-            'pydantic_ai.all_messages': json.dumps(settings.messages_to_otel_messages(state.message_history)),
+            'pydantic_ai.all_messages': json.dumps(settings.messages_to_otel_messages(list(state.message_history))),
             **settings.system_instructions_attributes(literal_instructions),
         }

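Accepting Sequence rather than list in these overloads is the usual variance fix for read-only parameters: list[T] is invariant, so a caller holding a tuple, or a list of a ModelMessage subtype, could not pass it where list[ModelMessage] was demanded, while Sequence[T] is covariant and advertises no mutating methods. A toy illustration of the type-checker behaviour (classes invented for the example):

from collections.abc import Sequence


class Animal: ...
class Dog(Animal): ...


def takes_list(xs: list[Animal]) -> None: ...
def takes_seq(xs: Sequence[Animal]) -> None: ...


dogs: list[Dog] = [Dog(), Dog()]
takes_seq(dogs)       # OK: Sequence[T] is covariant
takes_seq((Dog(),))   # OK: tuples are Sequences too
takes_list(dogs)      # rejected by a type checker: list[T] is invariant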
32 changes: 16 additions & 16 deletions pydantic_ai_slim/pydantic_ai/agent/abstract.py
@@ -126,7 +126,7 @@ async def run(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -144,7 +144,7 @@ async def run(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT],
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -161,7 +161,7 @@ async def run(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT] | None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -240,7 +240,7 @@ def run_sync(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -258,7 +258,7 @@ def run_sync(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT],
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -275,7 +275,7 @@ def run_sync(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT] | None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -346,7 +346,7 @@ def run_stream(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -364,7 +364,7 @@ def run_stream(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT],
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -382,7 +382,7 @@ async def run_stream(  # noqa C901
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT] | None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -493,7 +493,7 @@ async def stream_to_final(
                        raise exceptions.AgentRunError('Agent run produced final results')  # pragma: no cover
                    yielded = True

-            messages = graph_ctx.state.message_history.copy()
+            messages = list(graph_ctx.state.message_history)

            async def on_complete() -> None:
                """Called when the stream has completed.
@@ -537,7 +537,7 @@ async def on_complete() -> None:
                # if a tool function raised CallDeferred or ApprovalRequired.
                # In this case there's no response to stream, but we still let the user access the output etc as normal.
                yield StreamedRunResult(
-                    graph_ctx.state.message_history,
+                    list(graph_ctx.state.message_history),
                    graph_ctx.deps.new_message_index,
                    run_result=agent_run.result,
                )
@@ -558,7 +558,7 @@ def iter(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -575,7 +575,7 @@ def iter(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT],
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -593,7 +593,7 @@ async def iter(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT] | None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -944,7 +944,7 @@ async def to_cli(
        self: Self,
        deps: AgentDepsT = None,
        prog_name: str = 'pydantic-ai',
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
    ) -> None:
        """Run the agent in a CLI chat interface.

@@ -981,7 +981,7 @@ def to_cli_sync(
        self: Self,
        deps: AgentDepsT = None,
        prog_name: str = 'pydantic-ai',
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
    ) -> None:
        """Run the agent in a CLI chat interface with the non-async interface.

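The two non-signature changes in this file (messages = list(graph_ctx.state.message_history) and the first argument to StreamedRunResult) replace .copy() with list(...). The result is identical while the history is a list, but .copy() is a list method rather than part of the Sequence ABC, so list(...) keeps working if the state's history is ever held as a plain Sequence, while still handing out an independent snapshot. A toy version of the idea (names invented):

from collections.abc import Sequence


class RunState:
    def __init__(self) -> None:
        self.message_history: list[str] = []

    def snapshot(self) -> list[str]:
        # list(...) copies any Sequence; mutating the snapshot
        # cannot corrupt the internal history.
        return list(self.message_history)


state = RunState()
state.message_history.append('hello')
snap = state.snapshot()
snap.append('world')
assert state.message_history == ['hello']  # internal state untouched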
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/agent/wrapper.py
@@ -72,7 +72,7 @@ def iter(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -89,7 +89,7 @@ def iter(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT],
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -106,7 +106,7 @@ async def iter(
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT] | None = None,
-        message_history: list[_messages.ModelMessage] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
         deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
18 changes: 9 additions & 9 deletions pydantic_ai_slim/pydantic_ai/direct.py
@@ -10,7 +10,7 @@

 import queue
 import threading
-from collections.abc import Iterator
+from collections.abc import Iterator, Sequence
 from contextlib import AbstractAsyncContextManager
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -35,7 +35,7 @@

 async def model_request(
     model: models.Model | models.KnownModelName | str,
-    messages: list[messages.ModelMessage],
+    messages: Sequence[messages.ModelMessage],
     *,
     model_settings: settings.ModelSettings | None = None,
     model_request_parameters: models.ModelRequestParameters | None = None,
@@ -79,15 +79,15 @@ async def main():
     """
     model_instance = _prepare_model(model, instrument)
     return await model_instance.request(
-        messages,
+        list(messages),
         model_settings,
         model_request_parameters or models.ModelRequestParameters(),
     )


 def model_request_sync(
     model: models.Model | models.KnownModelName | str,
-    messages: list[messages.ModelMessage],
+    messages: Sequence[messages.ModelMessage],
     *,
     model_settings: settings.ModelSettings | None = None,
     model_request_parameters: models.ModelRequestParameters | None = None,
@@ -133,7 +133,7 @@ def model_request_sync(
     return _get_event_loop().run_until_complete(
         model_request(
             model,
-            messages,
+            list(messages),
             model_settings=model_settings,
             model_request_parameters=model_request_parameters,
             instrument=instrument,
@@ -143,7 +143,7 @@

 def model_request_stream(
     model: models.Model | models.KnownModelName | str,
-    messages: list[messages.ModelMessage],
+    messages: Sequence[messages.ModelMessage],
     *,
     model_settings: settings.ModelSettings | None = None,
     model_request_parameters: models.ModelRequestParameters | None = None,
@@ -191,15 +191,15 @@ async def main():
     """
     model_instance = _prepare_model(model, instrument)
     return model_instance.request_stream(
-        messages,
+        list(messages),
         model_settings,
         model_request_parameters or models.ModelRequestParameters(),
     )


 def model_request_stream_sync(
     model: models.Model | models.KnownModelName | str,
-    messages: list[messages.ModelMessage],
+    messages: Sequence[messages.ModelMessage],
     *,
     model_settings: settings.ModelSettings | None = None,
     model_request_parameters: models.ModelRequestParameters | None = None,
@@ -246,7 +246,7 @@ def model_request_stream_sync(
     """
     async_stream_cm = model_request_stream(
         model=model,
-        messages=messages,
+        messages=list(messages),
         model_settings=model_settings,
         model_request_parameters=model_request_parameters,
         instrument=instrument,
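With these signatures widened, the direct request helpers accept any sequence of messages and normalise it with list(...) before calling the model. A hedged usage sketch following the call pattern from the library's docs (the model name and message constructor are assumptions, not part of this diff):

from pydantic_ai.direct import model_request_sync
from pydantic_ai.messages import ModelRequest

# A tuple of messages should now type-check as well as a list.
history = (ModelRequest.user_text_prompt('What is the capital of France?'),)
response = model_request_sync('openai:gpt-4o', history)
print(response.parts)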