
Commit 6d7c53a

samuelcolvin and Kludex authored
Add Gemini 2.5 pro and improve CLI (#1317)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 07e335c commit 6d7c53a
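This commit replaces the per-invocation `Agent` construction with a module-level `cli_agent`, surfaces model-initialization errors via `infer_model`, and splits the interactive loop into testable `run_chat` / `handle_slash_command` helpers. A hedged sketch of driving the entry point programmatically; the `--model` flag spelling and the positional prompt are assumptions inferred from `args.model` and `args.prompt` in the diff below, not confirmed by it:

```python
# Hypothetical invocation sketch, not part of the commit. Assumes the
# argparse parser exposes a `--model` flag and a positional prompt, and
# that a Gemini API key is configured in the environment.
from pydantic_ai._cli import cli

exit_code = cli(['--model', 'google-gla:gemini-2.5-pro-exp-03-25', 'hello world'])
```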

4 files changed: +203, -64 lines

pydantic_ai_slim/pydantic_ai/_cli.py

Lines changed: 82 additions & 59 deletions

```diff
@@ -3,19 +3,20 @@
 import argparse
 import asyncio
 import sys
+from asyncio import CancelledError
 from collections.abc import Sequence
 from contextlib import ExitStack
 from datetime import datetime, timezone
 from importlib.metadata import version
 from pathlib import Path
-from typing import cast
+from typing import Any, cast
 
 from typing_inspection.introspection import get_literal_values
 
 from pydantic_ai.agent import Agent
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.messages import ModelMessage, PartDeltaEvent, TextPartDelta
-from pydantic_ai.models import KnownModelName
+from pydantic_ai.models import KnownModelName, infer_model
 
 try:
     import argcomplete
@@ -47,7 +48,7 @@ class SimpleCodeBlock(CodeBlock):
     This avoids a background color which messes up copy-pasting and sets the language name as dim prefix and suffix.
     """
 
-    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:  # pragma: no cover
+    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
         code = str(self.text).rstrip()
         yield Text(self.lexer_name, style='dim')
         yield Syntax(code, self.lexer_name, theme=self.theme, background_color='default', word_wrap=True)
@@ -57,7 +58,7 @@ def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderR
 class LeftHeading(Heading):
     """Customised headings in markdown to stop centering and prepend markdown style hashes."""
 
-    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:  # pragma: no cover
+    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
         # note we use `Style(bold=True)` not `self.style_name` here to disable underlining which is ugly IMHO
         yield Text(f'{"#" * int(self.tag[1:])} {self.text.plain}', style=Style(bold=True))
 
@@ -68,7 +69,21 @@ def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderR
 )
 
 
-def cli(args_list: Sequence[str] | None = None) -> int:  # noqa: C901  # pragma: no cover
+cli_agent = Agent()
+
+
+@cli_agent.system_prompt
+def cli_system_prompt() -> str:
+    now_utc = datetime.now(timezone.utc)
+    tzinfo = now_utc.astimezone().tzinfo
+    tzname = tzinfo.tzname(now_utc) if tzinfo else ''
+    return f"""\
+Help the user by responding to their request, the output should be concise and always written in markdown.
+The current date and time is {datetime.now()} {tzname}.
+The user is running {sys.platform}."""
+
+
+def cli(args_list: Sequence[str] | None = None) -> int:
     parser = argparse.ArgumentParser(
         prog='pai',
         description=f"""\
@@ -124,18 +139,10 @@ def cli(args_list: Sequence[str] | None = None) -> int:  # noqa: C901  # pragma:
             console.print(f'  {model}', highlight=False)
         return 0
 
-    now_utc = datetime.now(timezone.utc)
-    tzname = now_utc.astimezone().tzinfo.tzname(now_utc)  # type: ignore
     try:
-        agent = Agent(
-            model=args.model,
-            system_prompt=f"""\
-Help the user by responding to their request, the output should be concise and always written in markdown.
-The current date and time is {datetime.now()} {tzname}.
-The user is running {sys.platform}.""",
-        )
-    except UserError:
-        console.print(f'[red]Invalid model "{args.model}"[/red]')
+        cli_agent.model = infer_model(args.model)
+    except UserError as e:
+        console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]')
         return 1
 
     stream = not args.no_stream
@@ -148,67 +155,44 @@ def cli(args_list: Sequence[str] | None = None) -> int:  # noqa: C901  # pragma:
 
     if prompt := cast(str, args.prompt):
         try:
-            asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
+            asyncio.run(ask_agent(cli_agent, prompt, stream, console, code_theme))
         except KeyboardInterrupt:
             pass
         return 0
 
     history = Path.home() / '.pai-prompt-history.txt'
-    session = PromptSession(history=FileHistory(str(history)))  # type: ignore
+    # doing this instead of `PromptSession[Any](history=` allows mocking of PromptSession in tests
+    session: PromptSession[Any] = PromptSession(history=FileHistory(str(history)))
+    try:
+        return asyncio.run(run_chat(session, stream, cli_agent, console, code_theme))
+    except KeyboardInterrupt:  # pragma: no cover
+        return 0
+
+
+async def run_chat(session: PromptSession[Any], stream: bool, agent: Agent, console: Console, code_theme: str) -> int:
     multiline = False
     messages: list[ModelMessage] = []
 
     while True:
         try:
             auto_suggest = CustomAutoSuggest(['/markdown', '/multiline', '/exit'])
-            text = cast(str, session.prompt('pai ➤ ', auto_suggest=auto_suggest, multiline=multiline))
-        except (KeyboardInterrupt, EOFError):
+            text = await session.prompt_async('pai ➤ ', auto_suggest=auto_suggest, multiline=multiline)
+        except (KeyboardInterrupt, EOFError):  # pragma: no cover
             return 0
 
         if not text.strip():
             continue
 
-        ident_prompt = text.lower().strip(' ').replace(' ', '-').lstrip(' ')
+        ident_prompt = text.lower().strip().replace(' ', '-')
         if ident_prompt.startswith('/'):
-            if ident_prompt == '/markdown':
-                try:
-                    parts = messages[-1].parts
-                except IndexError:
-                    console.print('[dim]No markdown output available.[/dim]')
-                    continue
-                console.print('[dim]Markdown output of last question:[/dim]\n')
-                for part in parts:
-                    if part.part_kind == 'text':
-                        console.print(
-                            Syntax(
-                                part.content,
-                                lexer='markdown',
-                                theme=code_theme,
-                                word_wrap=True,
-                                background_color='default',
-                            )
-                        )
-
-            elif ident_prompt == '/multiline':
-                multiline = not multiline
-                if multiline:
-                    console.print(
-                        'Enabling multiline mode. '
-                        '[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
-                    )
-                else:
-                    console.print('Disabling multiline mode.')
-            elif ident_prompt == '/exit':
-                console.print('[dim]Exiting…[/dim]')
-                return 0
-            else:
-                console.print(f'[red]Unknown command[/red] [magenta]`{ident_prompt}`[/magenta]')
+            exit_value, multiline = handle_slash_command(ident_prompt, messages, multiline, console, code_theme)
+            if exit_value is not None:
+                return exit_value
        else:
             try:
-                messages = asyncio.run(ask_agent(agent, text, stream, console, code_theme, messages))
-            except KeyboardInterrupt:
+                messages = await ask_agent(agent, text, stream, console, code_theme, messages)
+            except CancelledError:  # pragma: no cover
                 console.print('[dim]Interrupted[/dim]')
-                messages = []
 
 
 async def ask_agent(
@@ -218,7 +202,7 @@ async def ask_agent(
     console: Console,
     code_theme: str,
     messages: list[ModelMessage] | None = None,
-) -> list[ModelMessage]:  # pragma: no cover
+) -> list[ModelMessage]:
     status = Status('[dim]Working on it…[/dim]', console=console)
 
     if not stream:
@@ -248,7 +232,7 @@ async def ask_agent(
 
 
 class CustomAutoSuggest(AutoSuggestFromHistory):
-    def __init__(self, special_suggestions: list[str] | None = None):  # pragma: no cover
+    def __init__(self, special_suggestions: list[str] | None = None):
         super().__init__()
         self.special_suggestions = special_suggestions or []
 
@@ -264,5 +248,44 @@ def get_suggestion(self, buffer: Buffer, document: Document) -> Suggestion | Non
         return suggestion
 
 
+def handle_slash_command(
+    ident_prompt: str, messages: list[ModelMessage], multiline: bool, console: Console, code_theme: str
+) -> tuple[int | None, bool]:
+    if ident_prompt == '/markdown':
+        try:
+            parts = messages[-1].parts
+        except IndexError:
+            console.print('[dim]No markdown output available.[/dim]')
+        else:
+            console.print('[dim]Markdown output of last question:[/dim]\n')
+            for part in parts:
+                if part.part_kind == 'text':
+                    console.print(
+                        Syntax(
+                            part.content,
+                            lexer='markdown',
+                            theme=code_theme,
+                            word_wrap=True,
+                            background_color='default',
+                        )
+                    )
+
+    elif ident_prompt == '/multiline':
+        multiline = not multiline
+        if multiline:
+            console.print(
+                'Enabling multiline mode. [dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
+            )
+        else:
+            console.print('Disabling multiline mode.')
+        return None, multiline
+    elif ident_prompt == '/exit':
+        console.print('[dim]Exiting…[/dim]')
+        return 0, multiline
+    else:
+        console.print(f'[red]Unknown command[/red] [magenta]`{ident_prompt}`[/magenta]')
+        return None, multiline
+
+
 def app():  # pragma: no cover
     sys.exit(cli())
```
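Removing the `# pragma: no cover` markers and extracting `run_chat` and `handle_slash_command` is what makes this code coverable: the slash-command logic can now be exercised without a live prompt session. A minimal sketch of such a test; the `'monokai'` theme and the asserted strings are illustrative assumptions, not from this commit:

```python
# Sketch: exercise handle_slash_command directly, no prompt loop needed.
from rich.console import Console

from pydantic_ai._cli import handle_slash_command

console = Console(record=True)  # record output so it can be asserted on

# '/exit' should tell the chat loop to return exit code 0 and leave
# multiline mode unchanged.
exit_value, multiline = handle_slash_command('/exit', [], False, console, 'monokai')
assert (exit_value, multiline) == (0, False)
assert 'Exiting' in console.export_text()
```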

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -106,6 +106,7 @@
     'google-gla:gemini-2.0-flash',
     'google-gla:gemini-2.0-flash-lite-preview-02-05',
     'google-gla:gemini-2.0-pro-exp-02-05',
+    'google-gla:gemini-2.5-pro-exp-03-25',
     'google-vertex:gemini-1.0-pro',
     'google-vertex:gemini-1.5-flash',
     'google-vertex:gemini-1.5-flash-8b',
@@ -116,6 +117,7 @@
     'google-vertex:gemini-2.0-flash',
     'google-vertex:gemini-2.0-flash-lite-preview-02-05',
     'google-vertex:gemini-2.0-pro-exp-02-05',
+    'google-vertex:gemini-2.5-pro-exp-03-25',
     'gpt-3.5-turbo',
     'gpt-3.5-turbo-0125',
     'gpt-3.5-turbo-0301',
```
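With the two entries above, the 2.5 Pro preview becomes part of `KnownModelName`, so the plain string form type-checks when handed to `Agent`. A hedged usage sketch; assumes a Gemini API key is set in the environment and uses the `result.data` accessor of this generation of the library:

```python
from pydantic_ai import Agent

# 'google-gla:gemini-2.5-pro-exp-03-25' is now a KnownModelName member,
# so type checkers accept the string without a cast.
agent = Agent('google-gla:gemini-2.5-pro-exp-03-25')
result = agent.run_sync('Say hello in one word.')
print(result.data)
```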

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -57,6 +57,7 @@
     'gemini-2.0-flash',
     'gemini-2.0-flash-lite-preview-02-05',
     'gemini-2.0-pro-exp-02-05',
+    'gemini-2.5-pro-exp-03-25',
 ]
 """Latest Gemini models."""
```
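The same name lands in `LatestGeminiModelNames`, so `GeminiModel` accepts it as a typed model name rather than a bare string. A sketch under the assumption that this release's `GeminiModel` constructor takes the model name plus a `provider` argument:

```python
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel

# 'gemini-2.5-pro-exp-03-25' is now in LatestGeminiModelNames.
model = GeminiModel('gemini-2.5-pro-exp-03-25', provider='google-gla')
agent = Agent(model)
```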
