3
3
import argparse
4
4
import asyncio
5
5
import sys
6
+ from asyncio import CancelledError
6
7
from collections .abc import Sequence
7
8
from contextlib import ExitStack
8
9
from datetime import datetime , timezone
9
10
from importlib .metadata import version
10
11
from pathlib import Path
11
- from typing import cast
12
+ from typing import Any , cast
12
13
13
14
from typing_inspection .introspection import get_literal_values
14
15
15
16
from pydantic_ai .agent import Agent
16
17
from pydantic_ai .exceptions import UserError
17
18
from pydantic_ai .messages import ModelMessage , PartDeltaEvent , TextPartDelta
18
- from pydantic_ai .models import KnownModelName
19
+ from pydantic_ai .models import KnownModelName , infer_model
19
20
20
21
try :
21
22
import argcomplete
@@ -47,7 +48,7 @@ class SimpleCodeBlock(CodeBlock):
47
48
This avoids a background color which messes up copy-pasting and sets the language name as dim prefix and suffix.
48
49
"""
49
50
50
- def __rich_console__ (self , console : Console , options : ConsoleOptions ) -> RenderResult : # pragma: no cover
51
+ def __rich_console__ (self , console : Console , options : ConsoleOptions ) -> RenderResult :
51
52
code = str (self .text ).rstrip ()
52
53
yield Text (self .lexer_name , style = 'dim' )
53
54
yield Syntax (code , self .lexer_name , theme = self .theme , background_color = 'default' , word_wrap = True )
@@ -57,7 +58,7 @@ def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderR
57
58
class LeftHeading (Heading ):
58
59
"""Customised headings in markdown to stop centering and prepend markdown style hashes."""
59
60
60
- def __rich_console__ (self , console : Console , options : ConsoleOptions ) -> RenderResult : # pragma: no cover
61
+ def __rich_console__ (self , console : Console , options : ConsoleOptions ) -> RenderResult :
61
62
# note we use `Style(bold=True)` not `self.style_name` here to disable underlining which is ugly IMHO
62
63
yield Text (f'{ "#" * int (self .tag [1 :])} { self .text .plain } ' , style = Style (bold = True ))
63
64
@@ -68,7 +69,21 @@ def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderR
68
69
)
69
70
70
71
71
- def cli (args_list : Sequence [str ] | None = None ) -> int : # noqa: C901 # pragma: no cover
72
+ cli_agent = Agent ()
73
+
74
+
75
+ @cli_agent .system_prompt
76
+ def cli_system_prompt () -> str :
77
+ now_utc = datetime .now (timezone .utc )
78
+ tzinfo = now_utc .astimezone ().tzinfo
79
+ tzname = tzinfo .tzname (now_utc ) if tzinfo else ''
80
+ return f"""\
81
+ Help the user by responding to their request, the output should be concise and always written in markdown.
82
+ The current date and time is { datetime .now ()} { tzname } .
83
+ The user is running { sys .platform } ."""
84
+
85
+
86
+ def cli (args_list : Sequence [str ] | None = None ) -> int :
72
87
parser = argparse .ArgumentParser (
73
88
prog = 'pai' ,
74
89
description = f"""\
@@ -124,18 +139,10 @@ def cli(args_list: Sequence[str] | None = None) -> int: # noqa: C901 # pragma:
124
139
console .print (f' { model } ' , highlight = False )
125
140
return 0
126
141
127
- now_utc = datetime .now (timezone .utc )
128
- tzname = now_utc .astimezone ().tzinfo .tzname (now_utc ) # type: ignore
129
142
try :
130
- agent = Agent (
131
- model = args .model ,
132
- system_prompt = f"""\
133
- Help the user by responding to their request, the output should be concise and always written in markdown.
134
- The current date and time is { datetime .now ()} { tzname } .
135
- The user is running { sys .platform } .""" ,
136
- )
137
- except UserError :
138
- console .print (f'[red]Invalid model "{ args .model } "[/red]' )
143
+ cli_agent .model = infer_model (args .model )
144
+ except UserError as e :
145
+ console .print (f'Error initializing [magenta]{ args .model } [/magenta]:\n [red]{ e } [/red]' )
139
146
return 1
140
147
141
148
stream = not args .no_stream
@@ -148,67 +155,44 @@ def cli(args_list: Sequence[str] | None = None) -> int: # noqa: C901 # pragma:
148
155
149
156
if prompt := cast (str , args .prompt ):
150
157
try :
151
- asyncio .run (ask_agent (agent , prompt , stream , console , code_theme ))
158
+ asyncio .run (ask_agent (cli_agent , prompt , stream , console , code_theme ))
152
159
except KeyboardInterrupt :
153
160
pass
154
161
return 0
155
162
156
163
history = Path .home () / '.pai-prompt-history.txt'
157
- session = PromptSession (history = FileHistory (str (history ))) # type: ignore
164
+ # doing this instead of `PromptSession[Any](history=` allows mocking of PromptSession in tests
165
+ session : PromptSession [Any ] = PromptSession (history = FileHistory (str (history )))
166
+ try :
167
+ return asyncio .run (run_chat (session , stream , cli_agent , console , code_theme ))
168
+ except KeyboardInterrupt : # pragma: no cover
169
+ return 0
170
+
171
+
172
+ async def run_chat (session : PromptSession [Any ], stream : bool , agent : Agent , console : Console , code_theme : str ) -> int :
158
173
multiline = False
159
174
messages : list [ModelMessage ] = []
160
175
161
176
while True :
162
177
try :
163
178
auto_suggest = CustomAutoSuggest (['/markdown' , '/multiline' , '/exit' ])
164
- text = cast ( str , session .prompt ('pai ➤ ' , auto_suggest = auto_suggest , multiline = multiline ) )
165
- except (KeyboardInterrupt , EOFError ):
179
+ text = await session .prompt_async ('pai ➤ ' , auto_suggest = auto_suggest , multiline = multiline )
180
+ except (KeyboardInterrupt , EOFError ): # pragma: no cover
166
181
return 0
167
182
168
183
if not text .strip ():
169
184
continue
170
185
171
- ident_prompt = text .lower ().strip (' ' ).replace (' ' , '-' ). lstrip ( ' ' )
186
+ ident_prompt = text .lower ().strip ().replace (' ' , '-' )
172
187
if ident_prompt .startswith ('/' ):
173
- if ident_prompt == '/markdown' :
174
- try :
175
- parts = messages [- 1 ].parts
176
- except IndexError :
177
- console .print ('[dim]No markdown output available.[/dim]' )
178
- continue
179
- console .print ('[dim]Markdown output of last question:[/dim]\n ' )
180
- for part in parts :
181
- if part .part_kind == 'text' :
182
- console .print (
183
- Syntax (
184
- part .content ,
185
- lexer = 'markdown' ,
186
- theme = code_theme ,
187
- word_wrap = True ,
188
- background_color = 'default' ,
189
- )
190
- )
191
-
192
- elif ident_prompt == '/multiline' :
193
- multiline = not multiline
194
- if multiline :
195
- console .print (
196
- 'Enabling multiline mode. '
197
- '[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
198
- )
199
- else :
200
- console .print ('Disabling multiline mode.' )
201
- elif ident_prompt == '/exit' :
202
- console .print ('[dim]Exiting…[/dim]' )
203
- return 0
204
- else :
205
- console .print (f'[red]Unknown command[/red] [magenta]`{ ident_prompt } `[/magenta]' )
188
+ exit_value , multiline = handle_slash_command (ident_prompt , messages , multiline , console , code_theme )
189
+ if exit_value is not None :
190
+ return exit_value
206
191
else :
207
192
try :
208
- messages = asyncio . run ( ask_agent (agent , text , stream , console , code_theme , messages ) )
209
- except KeyboardInterrupt :
193
+ messages = await ask_agent (agent , text , stream , console , code_theme , messages )
194
+ except CancelledError : # pragma: no cover
210
195
console .print ('[dim]Interrupted[/dim]' )
211
- messages = []
212
196
213
197
214
198
async def ask_agent (
@@ -218,7 +202,7 @@ async def ask_agent(
218
202
console : Console ,
219
203
code_theme : str ,
220
204
messages : list [ModelMessage ] | None = None ,
221
- ) -> list [ModelMessage ]: # pragma: no cover
205
+ ) -> list [ModelMessage ]:
222
206
status = Status ('[dim]Working on it…[/dim]' , console = console )
223
207
224
208
if not stream :
@@ -248,7 +232,7 @@ async def ask_agent(
248
232
249
233
250
234
class CustomAutoSuggest (AutoSuggestFromHistory ):
251
- def __init__ (self , special_suggestions : list [str ] | None = None ): # pragma: no cover
235
+ def __init__ (self , special_suggestions : list [str ] | None = None ):
252
236
super ().__init__ ()
253
237
self .special_suggestions = special_suggestions or []
254
238
@@ -264,5 +248,44 @@ def get_suggestion(self, buffer: Buffer, document: Document) -> Suggestion | Non
264
248
return suggestion
265
249
266
250
251
+ def handle_slash_command (
252
+ ident_prompt : str , messages : list [ModelMessage ], multiline : bool , console : Console , code_theme : str
253
+ ) -> tuple [int | None , bool ]:
254
+ if ident_prompt == '/markdown' :
255
+ try :
256
+ parts = messages [- 1 ].parts
257
+ except IndexError :
258
+ console .print ('[dim]No markdown output available.[/dim]' )
259
+ else :
260
+ console .print ('[dim]Markdown output of last question:[/dim]\n ' )
261
+ for part in parts :
262
+ if part .part_kind == 'text' :
263
+ console .print (
264
+ Syntax (
265
+ part .content ,
266
+ lexer = 'markdown' ,
267
+ theme = code_theme ,
268
+ word_wrap = True ,
269
+ background_color = 'default' ,
270
+ )
271
+ )
272
+
273
+ elif ident_prompt == '/multiline' :
274
+ multiline = not multiline
275
+ if multiline :
276
+ console .print (
277
+ 'Enabling multiline mode. [dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
278
+ )
279
+ else :
280
+ console .print ('Disabling multiline mode.' )
281
+ return None , multiline
282
+ elif ident_prompt == '/exit' :
283
+ console .print ('[dim]Exiting…[/dim]' )
284
+ return 0 , multiline
285
+ else :
286
+ console .print (f'[red]Unknown command[/red] [magenta]`{ ident_prompt } `[/magenta]' )
287
+ return None , multiline
288
+
289
+
267
290
def app (): # pragma: no cover
268
291
sys .exit (cli ())
0 commit comments