1+ import json
2+ import logging
13import openai
24from pathlib import PurePosixPath
5+ import textwrap
6+ from typing import Any , Mapping , Self , Sequence , override
37
48from .common import Assistant , Session , Toolbox
59
610
7- # https://aider.chat/docs/more-info.html
8- # https://github.com/Aider-AI/aider/blob/main/aider/prompts.py
9- _INSTRUCTIONS = """\
10- You are an expert software engineer, who writes correct and concise code.
11- """
11+ _logger = logging .getLogger (__name__ )
1212
13- _tools = [ # TODO
14- {
13+
14+ def _function_tool_param (
15+ name : str ,
16+ description : str ,
17+ inputs : Mapping [str , Any ] | None = None ,
18+ required_inputs : Sequence [str ] | None = None ,
19+ ) -> openai .types .beta .FunctionToolParam :
20+ return {
1521 "type" : "function" ,
1622 "function" : {
17- "name" : "read_file" ,
18- "description" : "Get a file's contents" ,
23+ "name" : name ,
24+ "description" : textwrap . dedent ( description ) ,
1925 "parameters" : {
2026 "type" : "object" ,
21- "properties" : {
22- "path" : {
23- "type" : "string" ,
24- "description" : "Path of the file to be read" ,
25- },
26- },
27- "required" : ["path" ],
27+ "additionalProperties" : False ,
28+ "properties" : inputs or {},
29+ "required" : required_inputs or [],
2830 },
31+ "strict" : True ,
2932 },
30- },
31- {
32- "type" : "function" ,
33- "function" : {
34- "name" : "write_file" ,
35- "description" : "Update a file's contents" ,
36- "parameters" : {
37- "type" : "object" ,
38- "properties" : {
39- "path" : {
40- "type" : "string" ,
41- "description" : "Path of the file to be updated" ,
42- },
43- "contents" : {
44- "type" : "string" ,
45- "description" : "New contents of the file" ,
46- },
47- },
48- "required" : ["path" , "contents" ],
33+ }
34+
35+
36+ _tools = [
37+ _function_tool_param (
38+ name = "list_files" ,
39+ description = "List all available files" ,
40+ ),
41+ _function_tool_param (
42+ name = "read_file" ,
43+ description = "Get a file's contents" ,
44+ inputs = {
45+ "path" : {
46+ "type" : "string" ,
47+ "description" : "Path of the file to be read" ,
48+ },
49+ },
50+ required_inputs = ["path" ],
51+ ),
52+ _function_tool_param (
53+ name = "write_file" ,
54+ description = """\
55+ Set a file's contents
56+
57+ The file will be created if it does not already exist.
58+ """ ,
59+ inputs = {
60+ "path" : {
61+ "type" : "string" ,
62+ "description" : "Path of the file to be updated" ,
63+ },
64+ "contents" : {
65+ "type" : "string" ,
66+ "description" : "New contents of the file" ,
4967 },
5068 },
51- },
69+ required_inputs = ["path" , "contents" ],
70+ ),
5271]
5372
5473
74+ # https://aider.chat/docs/more-info.html
75+ # https://github.com/Aider-AI/aider/blob/main/aider/prompts.py
76+ _INSTRUCTIONS = """\
77+ You are an expert software engineer, who writes correct and concise code.
78+ Use the provided functions to find the filesyou need to answer the query,
79+ read the content of the relevant ones, and save the changes you suggest.
80+ """
81+
82+
5583class OpenAIAssistant (Assistant ):
5684 """An OpenAI-backed assistant
5785
@@ -66,26 +94,73 @@ def __init__(self) -> None:
6694 self ._client = openai .OpenAI ()
6795
6896 def run (self , prompt : str , toolbox : Toolbox ) -> Session :
69- # TODO: Switch to the thread run API, using tools to leverage toolbox
70- # methods.
71- # assistant = client.beta.assistants.create(
72- # instructions=_INSTRUCTIONS,
73- # model="gpt-4o",
74- # tools=_tools,
75- # )
76- # thread = client.beta.threads.create()
77- # message = client.beta.threads.messages.create(
78- # thread_id=thread.id,
79- # role="user",
80- # content="What's the weather in San Francisco today and the likelihood it'll rain?",
81- # )
82- completion = self ._client .chat .completions .create (
83- messages = [
84- {"role" : "system" , "content" : _INSTRUCTIONS },
85- {"role" : "user" , "content" : prompt },
86- ],
97+ # TODO: Reuse assistant.
98+ assistant = self ._client .beta .assistants .create (
99+ instructions = _INSTRUCTIONS ,
87100 model = "gpt-4o" ,
101+ tools = _tools ,
102+ )
103+ thread = self ._client .beta .threads .create ()
104+
105+ message = self ._client .beta .threads .messages .create (
106+ thread_id = thread .id ,
107+ role = "user" ,
108+ content = prompt ,
88109 )
89- content = completion .choices [0 ].message .content or ""
90- toolbox .write_file (PurePosixPath (f"{ completion .id } .txt" ), content )
110+ print (message )
111+
112+ with self ._client .beta .threads .runs .stream (
113+ thread_id = thread .id ,
114+ assistant_id = assistant .id ,
115+ event_handler = _EventHandler (self ._client , toolbox ),
116+ ) as stream :
117+ stream .until_done ()
118+
91119 return Session (0 )
120+
121+
122+ class _EventHandler (openai .AssistantEventHandler ):
123+ def __init__ (self , client : openai .Client , toolbox : Toolbox ) -> None :
124+ super ().__init__ ()
125+ self ._client = client
126+ self ._toolbox = toolbox
127+
128+ def clone (self ) -> Self :
129+ return self .__class__ (self ._client , self ._toolbox )
130+
131+ @override
132+ def on_event (self , event : Any ) -> None :
133+ _logger .debug ("Event: %s" , event )
134+ if event .event == "thread.run.requires_action" :
135+ run_id = event .data .id # Retrieve the run ID from the event data
136+ self ._handle_action (run_id , event .data )
137+ # TODO: Handle (log?) other events.
138+
139+ def _handle_action (self , run_id : str , data : Any ) -> None :
140+ tool_outputs = list [Any ]()
141+ for tool in data .required_action .submit_tool_outputs .tool_calls :
142+ name = tool .function .name
143+ inputs = json .loads (tool .function .arguments )
144+ _logger .info ("Requested tool: %s" , tool )
145+ if name == "read_file" :
146+ path = PurePosixPath (inputs ["path" ])
147+ output = self ._toolbox .read_file (path )
148+ elif name == "write_file" :
149+ path = PurePosixPath (inputs ["path" ])
150+ contents = inputs ["contents" ]
151+ self ._toolbox .write_file (path , contents )
152+ output = "OK"
153+ elif name == "list_files" :
154+ assert not inputs
155+ output = "\n " .join (str (p ) for p in self ._toolbox .list_files ())
156+ tool_outputs .append ({"tool_call_id" : tool .id , "output" : output })
157+
158+ run = self .current_run
159+ assert run , "No ongoing run"
160+ with self ._client .beta .threads .runs .submit_tool_outputs_stream (
161+ thread_id = run .thread_id ,
162+ run_id = run .id ,
163+ tool_outputs = tool_outputs ,
164+ event_handler = self .clone (),
165+ ) as stream :
166+ stream .until_done ()
0 commit comments