Commit d0a59e2

Server improvements, common hallucination fix
1 parent da057d4 commit d0a59e2

File tree

7 files changed: +293 -18 lines


benchmarks/run.py

Whitespace-only changes.

interpreter/core/async_core.py

Lines changed: 33 additions & 13 deletions
@@ -25,6 +25,7 @@ def __init__(self, *args, **kwargs):
         self.stop_event = threading.Event()
         self.output_queue = None
         self.id = os.getenv("INTERPRETER_ID", datetime.now().timestamp())
+        self.print = True  # Will print output

         self.server = Server(self)

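The new self.print flag only gates the fenced console echo inside respond() below; chunks are pushed onto output_queue either way. A host embedding the server could silence the echo like this (the AsyncInterpreter class name is an assumption from context, not shown in this hunk):

    interpreter = AsyncInterpreter()  # assumed constructor; the diff only shows __init__'s body
    interpreter.print = False         # stop echoing to stdout; output_queue still receives every chunk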
@@ -45,41 +46,60 @@ async def input(self, chunk):
         elif "end" in chunk:
             # If the user is done talking, the interpreter should respond.

-            # But first, process any client messages.
-            if self.messages[-1]["role"] == "client":
+            run_code = None  # Will later default to auto_run unless the user makes a command here
+
+            # But first, process any commands.
+            if self.messages[-1]["type"] == "command":
                 command = self.messages[-1]["content"]
                 self.messages = self.messages[:-1]

                 if command == "stop":
+                    # Any start flag would have stopped it a moment ago, but to be sure:
+                    self.stop_event.set()
+                    self.respond_thread.join()
                     return
                 if command == "go":
                     # This is to approve code.
-                    # We do nothing, as self.respond will run the last code block if the last message is one.
+                    run_code = True
                     pass

             self.stop_event.clear()
-            self.respond_thread = threading.Thread(target=self.respond)
+            self.respond_thread = threading.Thread(
+                target=self.respond, args=(run_code,)
+            )
             self.respond_thread.start()

     async def output(self):
         if self.output_queue == None:
             self.output_queue = janus.Queue()
         return await self.output_queue.async_q.get()

-    def respond(self):
+    def respond(self, run_code=None):
+        if run_code == None:
+            run_code = self.auto_run
+
         for chunk in self._respond_and_store():
-            if chunk["type"] in ["code", "output"]:
-                if "start" in chunk:
-                    print("\n\n```" + chunk["format"], flush=True)
-                if "end" in chunk:
-                    print("\n```", flush=True)
-            print(chunk.get("content", ""), end="", flush=True)
+            if chunk["type"] == "confirmation":
+                if run_code:
+                    continue  # We don't need to send out confirmation chunks on the server. I don't even like them.
+                else:
+                    break

             if self.stop_event.is_set():
                 return
+
+            if self.print:
+                if chunk["type"] in ["code", "output"]:
+                    if "start" in chunk:
+                        print("\n\n```" + chunk["format"], flush=True)
+                    if "end" in chunk:
+                        print("\n```", flush=True)
+                print(chunk.get("content", ""), end="", flush=True)

             self.output_queue.sync_q.put(chunk)

         self.output_queue.sync_q.put(
-            {"role": "server", "type": "status", "content": "complete"}
+            {"role": "assistant", "type": "status", "content": "complete"}
         )

     def accumulate(self, chunk):
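Taken together, input() and respond() now make code approval explicit instead of implicit. A rough sketch of the message shapes the new branches react to (the exact chunk framing around "start"/"end" is inferred from this diff and may differ):

    # The trailing message when an "end" chunk arrives decides what happens next:
    stop_command = {"role": "client", "type": "command", "content": "stop"}  # set stop_event, join respond_thread
    go_command = {"role": "client", "type": "command", "content": "go"}      # respond(run_code=True): approve the pending code block
    # Any other trailing message leaves run_code=None, so respond() falls back to self.auto_run.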
@@ -202,7 +222,7 @@ async def set_settings(payload: Dict[str, Any]):

     return {"status": "success"}

-@router.get("/interpreter/{setting}")
+@router.get("/settings/{setting}")
 async def get_setting(setting: str):
     if hasattr(async_interpreter, setting):
         setting_value = getattr(async_interpreter, setting)
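Reading a single setting now lives under the same /settings prefix as the POST handler above it, instead of /interpreter/{setting}. A minimal client sketch (host, port, and the exact response shape are assumptions, not part of this diff):

    import requests

    BASE = "http://localhost:8000"  # wherever the FastAPI app is served

    # Before this commit the route was GET /interpreter/auto_run.
    resp = requests.get(f"{BASE}/settings/auto_run")
    print(resp.json())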

interpreter/core/llm/llm.py

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,8 @@
     display_markdown_message,
 )
 from .run_function_calling_llm import run_function_calling_llm
+
+# from .run_tool_calling_llm import run_tool_calling_llm
 from .run_text_llm import run_text_llm
 from .utils.convert_to_openai_messages import convert_to_openai_messages

@@ -283,6 +285,7 @@ def run(self, messages):

         if self.supports_functions:
             yield from run_function_calling_llm(self, params)
+            # yield from run_tool_calling_llm(self, params)
         else:
             yield from run_text_llm(self, params)

interpreter/core/llm/run_function_calling_llm.py

Lines changed: 3 additions & 3 deletions
@@ -31,9 +31,9 @@ def run_function_calling_llm(llm, request_params):
     request_params["functions"] = [function_schema]

     # Add OpenAI's recommended function message
-    request_params["messages"][0][
-        "content"
-    ] += "\nUse ONLY the function you have been provided with — 'execute(language, code)'."
+    # request_params["messages"][0][
+    #     "content"
+    # ] += "\nUse ONLY the function you have been provided with — 'execute(language, code)'."

     ## Convert output to LMC format

interpreter/core/llm/run_tool_calling_llm.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+from .utils.merge_deltas import merge_deltas
+from .utils.parse_partial_json import parse_partial_json
+
+tool_schema = {
+    "type": "function",
+    "function": {
+        "name": "execute",
+        "description": "Executes code on the user's machine **in the users local environment** and returns the output",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "language": {
+                    "type": "string",
+                    "description": "The programming language (required parameter to the `execute` function)",
+                    "enum": [
+                        # This will be filled dynamically with the languages OI has access to.
+                    ],
+                },
+                "code": {
+                    "type": "string",
+                    "description": "The code to execute (required)",
+                },
+            },
+            "required": ["language", "code"],
+        },
+    },
+}
+
+
+def run_tool_calling_llm(llm, request_params):
+    ## Setup
+
+    # Add languages OI has access to
+    tool_schema["function"]["parameters"]["properties"]["language"]["enum"] = [
+        i.name.lower() for i in llm.interpreter.computer.terminal.languages
+    ]
+    request_params["tools"] = [tool_schema]
+
+    # Add OpenAI's recommended function message
+    # request_params["messages"][0][
+    #     "content"
+    # ] += "\nUse ONLY the function you have been provided with — 'execute(language, code)'."
+
+    ## Convert output to LMC format
+
+    accumulated_deltas = {}
+    language = None
+    code = ""
+
+    for chunk in llm.completions(**request_params):
+        if "choices" not in chunk or len(chunk["choices"]) == 0:
+            # This happens sometimes
+            continue
+
+        delta = chunk["choices"][0]["delta"]
+
+        # Convert tool call into function call, which we have great parsing logic for below
+        if "tool_calls" in delta:
+            if (
+                len(delta["tool_calls"]) > 0
+                and "function_call" in delta["tool_calls"][0]
+            ):
+                delta["function_call"] = delta["tool_calls"][0]["function_call"]
+
+        # Accumulate deltas
+        accumulated_deltas = merge_deltas(accumulated_deltas, delta)
+
+        if "content" in delta and delta["content"]:
+            yield {"type": "message", "content": delta["content"]}
+
+        if (
+            accumulated_deltas.get("function_call")
+            and "arguments" in accumulated_deltas["function_call"]
+            and accumulated_deltas["function_call"]["arguments"]
+        ):
+            if (
+                "name" in accumulated_deltas["function_call"]
+                and accumulated_deltas["function_call"]["name"] == "execute"
+            ):
+                arguments = accumulated_deltas["function_call"]["arguments"]
+                arguments = parse_partial_json(arguments)
+
+                if arguments:
+                    if (
+                        language is None
+                        and "language" in arguments
+                        and "code"
+                        in arguments  # <- This ensures we're *finished* typing language, as opposed to partially done
+                        and arguments["language"]
+                    ):
+                        language = arguments["language"]
+
+                    if language is not None and "code" in arguments:
+                        # Calculate the delta (new characters only)
+                        code_delta = arguments["code"][len(code) :]
+                        # Update the code
+                        code = arguments["code"]
+                        # Yield the delta
+                        if code_delta:
+                            yield {
+                                "type": "code",
+                                "format": language,
+                                "content": code_delta,
+                            }
+                else:
+                    if llm.interpreter.verbose:
+                        print("Arguments not a dict.")
+
+            # Common hallucinations
+            elif "name" in accumulated_deltas["function_call"] and (
+                accumulated_deltas["function_call"]["name"] == "python"
+                or accumulated_deltas["function_call"]["name"] == "functions"
+            ):
+                if llm.interpreter.verbose:
+                    print("Got direct python call")
+                if language is None:
+                    language = "python"
+
+                if language is not None:
+                    # Pull the code string straight out of the "arguments" string
+                    code_delta = accumulated_deltas["function_call"]["arguments"][
+                        len(code) :
+                    ]
+                    # Update the code
+                    code = accumulated_deltas["function_call"]["arguments"]
+                    # Yield the delta
+                    if code_delta:
+                        yield {
+                            "type": "code",
+                            "format": language,
+                            "content": code_delta,
+                        }
+
+            else:
+                # If name exists and it's not "execute" or "python" or "functions", who knows what's going on.
+                if "name" in accumulated_deltas["function_call"]:
+                    yield {
+                        "type": "code",
+                        "format": "python",
+                        "content": accumulated_deltas["function_call"]["name"],
+                    }
+    return
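The heart of the converter is the delta slicing: each parsed partial contains the full code streamed so far, and only the new suffix is yielded as an LMC code chunk. A standalone replay of that logic, with parse_partial_json and the model stream stubbed out as canned snapshots:

    parsed_snapshots = [
        {"language": "python", "code": "print("},
        {"language": "python", "code": "print(1 + 1)"},
    ]

    code = ""
    for arguments in parsed_snapshots:
        code_delta = arguments["code"][len(code):]  # new characters only
        code = arguments["code"]
        if code_delta:
            print({"type": "code", "format": arguments["language"], "content": code_delta})
    # -> {'type': 'code', 'format': 'python', 'content': 'print('}
    # -> {'type': 'code', 'format': 'python', 'content': '1 + 1)'}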

interpreter/core/respond.py

Lines changed: 12 additions & 0 deletions
@@ -146,6 +146,18 @@ def respond(interpreter):
                 code = code[2:].strip()
                 if interpreter.verbose:
                     print("Removing `\n")
+                interpreter.messages[-1]["content"] = code  # So the LLM can see it.
+
+            # A common hallucination
+            if code.startswith("functions.execute("):
+                code = code.replace("functions.execute(", "").rstrip(")")
+                code_dict = json.loads(code)
+                language = code_dict.get("language", language)
+                code = code_dict.get("code", code)
+                interpreter.messages[-1]["content"] = code  # So the LLM can see it.
+                interpreter.messages[-1][
+                    "format"
+                ] = language  # So the LLM can see it.

             if language == "text" or language == "markdown":
                 # It does this sometimes just to take notes. Let it, it's useful.
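This branch unwraps a specific hallucination: the model emitting a literal functions.execute(...) call as the body of a code block instead of issuing a real function call. A standalone replay with a made-up payload (it mirrors the diff and, like it, assumes the wrapped argument is valid JSON):

    import json

    code = 'functions.execute({"language": "python", "code": "print(1 + 1)"})'
    language = "python"  # whatever the block was originally labeled

    if code.startswith("functions.execute("):
        code = code.replace("functions.execute(", "").rstrip(")")
        code_dict = json.loads(code)
        language = code_dict.get("language", language)
        code = code_dict.get("code", code)

    print(language, repr(code))  # python 'print(1 + 1)'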
