Skip to content

Commit 1510d85

Browse files
committed
Better tool output for non-Anthropic models
1 parent c365866 commit 1510d85

File tree

4 files changed

+127
-94
lines changed

4 files changed

+127
-94
lines changed

interpreter/commands.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def _handle_set_command(self, parts: list[str]) -> bool:
160160
value_str = parts[2]
161161
type_hint, _ = SETTINGS[param]
162162
try:
163+
self.interpreter._client = (
164+
None # Reset client, in case they changed API key or API base
165+
)
163166
value = parse_value(value_str, type_hint)
164167
setattr(self.interpreter, param, value)
165168
print(f"Set {param} = {value}")

interpreter/interpreter.py

Lines changed: 69 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ async def async_respond(self, user_input=None):
249249
provider = self.provider # Keep existing provider if set
250250
max_tokens = self.max_tokens # Keep existing max_tokens if set
251251

252-
if self.model == "claude-3-5-sonnet-latest":
253-
# For some reason, Litellm can't find the model info for claude-3-5-sonnet-latest
252+
if self.model in ["claude-3-5-sonnet-latest", "claude-3-5-sonnet-20241022"]:
253+
# For some reason, Litellm can't find the model info for these
254254
provider = "anthropic"
255255

256256
# Only try to get model info if we need either provider or max_tokens
@@ -294,33 +294,33 @@ async def async_respond(self, user_input=None):
294294

295295
self._spinner.start()
296296

297-
enable_prompt_caching = False
298297
betas = [COMPUTER_USE_BETA_FLAG]
299298

300-
if enable_prompt_caching:
301-
betas.append(PROMPT_CACHING_BETA_FLAG)
302-
image_truncation_threshold = 50
303-
system["cache_control"] = {"type": "ephemeral"}
304-
305299
edit = ToolRenderer()
306300

307301
if (
308302
provider == "anthropic" and not self.serve
309303
): # Server can't handle Anthropic yet
310304
if self._client is None:
311-
if self.api_key:
312-
self._client = Anthropic(api_key=self.api_key)
313-
else:
314-
self._client = Anthropic()
305+
anthropic_params = {}
306+
if self.api_key is not None:
307+
anthropic_params["api_key"] = self.api_key
308+
if self.api_base is not None:
309+
anthropic_params["base_url"] = self.api_base
310+
self._client = Anthropic(**anthropic_params)
315311

316312
if self.debug:
317313
print("Sending messages:", self.messages, "\n")
318314

315+
model = self.model
316+
if model.startswith("anthropic/"):
317+
model = model[len("anthropic/") :]
318+
319319
# Use Anthropic API which supports betas
320320
raw_response = self._client.beta.messages.create(
321321
max_tokens=max_tokens,
322322
messages=self.messages,
323-
model=self.model,
323+
model=model,
324324
system=system["text"],
325325
tools=tool_collection.to_params(),
326326
betas=betas,
@@ -698,7 +698,7 @@ async def async_respond(self, user_input=None):
698698
"temperature": self.temperature,
699699
"api_key": self.api_key,
700700
"api_version": self.api_version,
701-
"parallel_tool_calls": False,
701+
# "parallel_tool_calls": True,
702702
}
703703

704704
if self.tool_calling:
@@ -707,13 +707,32 @@ async def async_respond(self, user_input=None):
707707
params["stream"] = False
708708
stream = False
709709

710-
if self.debug:
711-
print(params)
710+
if provider == "anthropic" and self.tool_calling:
711+
params["tools"] = tool_collection.to_params()
712+
for t in params["tools"]:
713+
t["function"] = {"name": t["name"]}
714+
if t["name"] == "computer":
715+
t["function"]["parameters"] = {
716+
"display_height_px": t["display_height_px"],
717+
"display_width_px": t["display_width_px"],
718+
"display_number": t["display_number"],
719+
}
720+
params["extra_headers"] = {
721+
"anthropic-beta": "computer-use-2024-10-22"
722+
}
712723

713-
if self.debug:
714-
print("Sending request...", params)
724+
# if self.debug:
725+
# print("Sending request...", params)
726+
# time.sleep(3)
715727

716-
time.sleep(3)
728+
if self.debug:
729+
print("Messages:")
730+
for m in self.messages:
731+
if len(str(m)) > 1000:
732+
print(str(m)[:1000] + "...")
733+
else:
734+
print(str(m))
735+
print()
717736

718737
raw_response = litellm.completion(**params)
719738

@@ -856,6 +875,8 @@ async def async_respond(self, user_input=None):
856875
else:
857876
user_approval = input("\nRun tool(s)? (y/n): ").lower().strip()
858877

878+
user_content_to_add = []
879+
859880
for tool_call in message.tool_calls:
860881
function_arguments = json.loads(tool_call.function.arguments)
861882

@@ -869,43 +890,46 @@ async def async_respond(self, user_input=None):
869890

870891
if self.tool_calling:
871892
if result.base64_image:
872-
# Add image to tool result
873893
self.messages.append(
874894
{
875895
"role": "tool",
876-
"content": "The user will reply with the image outputted by the tool.",
896+
"content": "The user will reply with the tool's image output.",
877897
"tool_call_id": tool_call.id,
878898
}
879899
)
880-
self.messages.append(
900+
user_content_to_add.append(
881901
{
882-
"role": "user",
883-
"content": [
884-
{
885-
"type": "image_url",
886-
"image_url": {
887-
"url": f"data:image/png;base64,{result.base64_image}",
888-
},
889-
}
890-
],
891-
}
892-
)
893-
else:
894-
self.messages.append(
895-
{
896-
"role": "tool",
897-
"content": json.dumps(dataclasses.asdict(result)),
898-
"tool_call_id": tool_call.id,
902+
"type": "image_url",
903+
"image_url": {
904+
"url": f"data:image/png;base64,{result.base64_image}",
905+
},
899906
}
900907
)
901908
else:
902-
self.messages.append(
903-
{
904-
"role": "user",
905-
"content": "This was the output of the tool call. What does it mean/what's next?"
906-
+ json.dumps(dataclasses.asdict(result)),
907-
}
909+
text_content = (
910+
"This was the output of the tool call. What does it mean/what's next?\n"
911+
+ (result.output or "")
908912
)
913+
if result.base64_image:
914+
content = [
915+
{"type": "text", "text": text_content},
916+
{
917+
"type": "image",
918+
"image_url": {
919+
"url": "data:image/png;base64,"
920+
+ result.base64_image
921+
},
922+
},
923+
]
924+
else:
925+
content = text_content
926+
927+
self.messages.append({"role": "user", "content": content})
928+
929+
if user_content_to_add:
930+
self.messages.append(
931+
{"role": "user", "content": user_content_to_add}
932+
)
909933

910934
def _ask_user_approval(self) -> str:
911935
"""Ask user for approval to run a tool"""

interpreter/misc/get_input.py

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ async def async_get_input(
1111
placeholder_color: str = "gray",
1212
multiline_support: bool = True,
1313
) -> str:
14-
placeholder_text = "Describe command"
14+
# placeholder_text = "Describe command"
15+
placeholder_text = 'Use """ for multi-line input'
1516
history = InMemoryHistory()
1617
session = PromptSession(
1718
history=history,
@@ -27,11 +28,16 @@ async def async_get_input(
2728
def _(event):
2829
current_line = event.current_buffer.document.current_line.rstrip()
2930

30-
if current_line == '"""':
31-
multiline[0] = not multiline[0]
31+
if not multiline[0] and current_line.endswith('"""'):
32+
# Enter multiline mode
33+
multiline[0] = True
3234
event.current_buffer.insert_text("\n")
33-
if not multiline[0]: # If exiting multiline mode, submit
34-
event.current_buffer.validate_and_handle()
35+
return
36+
37+
if multiline[0] and current_line.startswith('"""'):
38+
# Exit multiline mode and submit
39+
multiline[0] = False
40+
event.current_buffer.validate_and_handle()
3541
return
3642

3743
if multiline[0]:
@@ -55,50 +61,50 @@ def _(event):
5561
return result
5662

5763

58-
def get_input(
59-
placeholder_text: Optional[str] = None,
60-
placeholder_color: str = "gray",
61-
multiline_support: bool = True,
62-
) -> str:
63-
placeholder_text = "Describe command"
64-
history = InMemoryHistory()
65-
session = PromptSession(
66-
history=history,
67-
enable_open_in_editor=False,
68-
enable_history_search=False,
69-
auto_suggest=None,
70-
multiline=True,
71-
)
72-
kb = KeyBindings()
73-
multiline = [False]
64+
# def get_input(
65+
# placeholder_text: Optional[str] = None,
66+
# placeholder_color: str = "gray",
67+
# multiline_support: bool = True,
68+
# ) -> str:
69+
# placeholder_text = "Describe command"
70+
# history = InMemoryHistory()
71+
# session = PromptSession(
72+
# history=history,
73+
# enable_open_in_editor=False,
74+
# enable_history_search=False,
75+
# auto_suggest=None,
76+
# multiline=True,
77+
# )
78+
# kb = KeyBindings()
79+
# multiline = [False]
7480

75-
@kb.add("enter")
76-
def _(event):
77-
current_line = event.current_buffer.document.current_line.rstrip()
81+
# @kb.add("enter")
82+
# def _(event):
83+
# current_line = event.current_buffer.document.current_line.rstrip()
7884

79-
if current_line == '"""':
80-
multiline[0] = not multiline[0]
81-
event.current_buffer.insert_text("\n")
82-
if not multiline[0]: # If exiting multiline mode, submit
83-
event.current_buffer.validate_and_handle()
84-
return
85+
# if current_line == '"""':
86+
# multiline[0] = not multiline[0]
87+
# event.current_buffer.insert_text("\n")
88+
# if not multiline[0]: # If exiting multiline mode, submit
89+
# event.current_buffer.validate_and_handle()
90+
# return
8591

86-
if multiline[0]:
87-
event.current_buffer.insert_text("\n")
88-
else:
89-
event.current_buffer.validate_and_handle()
92+
# if multiline[0]:
93+
# event.current_buffer.insert_text("\n")
94+
# else:
95+
# event.current_buffer.validate_and_handle()
9096

91-
result = session.prompt(
92-
"> ",
93-
placeholder=HTML(f'<style fg="{placeholder_color}">{placeholder_text}</style>')
94-
if placeholder_text
95-
else None,
96-
key_bindings=kb,
97-
complete_while_typing=False,
98-
enable_suspend=False,
99-
search_ignore_case=True,
100-
include_default_pygments_style=False,
101-
input_processors=[],
102-
enable_system_prompt=False,
103-
)
104-
return result
97+
# result = session.prompt(
98+
# "> ",
99+
# placeholder=HTML(f'<style fg="{placeholder_color}">{placeholder_text}</style>')
100+
# if placeholder_text
101+
# else None,
102+
# key_bindings=kb,
103+
# complete_while_typing=False,
104+
# enable_suspend=False,
105+
# search_ignore_case=True,
106+
# include_default_pygments_style=False,
107+
# input_processors=[],
108+
# enable_system_prompt=False,
109+
# )
110+
# return result

interpreter/profiles.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class Profile:
3232
def __init__(self):
3333
# Default values if no profile exists
3434
# Model configuration
35-
self.model = "claude-3-5-sonnet-latest" # The LLM model to use
35+
self.model = "claude-3-5-sonnet-20241022" # The LLM model to use
3636
self.provider = (
3737
None # The model provider (e.g. anthropic, openai) None will auto-detect
3838
)

Comments: 0 commit comments