|
@@ -1,4 +1,5 @@
 import asyncio
+import dataclasses
 import json
 import os
 import platform
@@ -23,10 +24,8 @@
 from urllib.parse import quote
 
 import litellm
-import pyautogui
 from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
 from anthropic.types.beta import (
-    BetaCacheControlEphemeralParam,
     BetaContentBlock,
     BetaContentBlockParam,
     BetaImageBlockParam,
|
@@ -223,7 +222,8 @@ async def async_respond(self):
         )
 
         model_info = litellm.get_model_info(self.model)
-        provider = model_info["litellm_provider"]
+        if self.provider is None:
+            self.provider = model_info["litellm_provider"]
         max_tokens = model_info["max_tokens"]
 
         while True:
|
@@ -449,9 +449,164 @@ async def async_respond(self):
             )
 
             else:
-                # LiteLLM implementation would go here
-                # (I can add this if you'd like, but focusing on the Anthropic path for now)
-                pass
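+                # Non-Anthropic providers go through LiteLLM, exposing the
+                # bash tool with an OpenAI-style function-calling schema.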
+                tools = []
+                if "interpreter" in self.tools:
+                    tools.append(
+                        {
+                            "type": "function",
+                            "function": {
+                                "name": "bash",
+                                "description": """Run commands in a bash shell\n
+                                * When invoking this tool, the contents of the \"command\" parameter does NOT need to be XML-escaped.\n
+                                * You don't have access to the internet via this tool.\n
+                                * You do have access to a mirror of common linux and python packages via apt and pip.\n
+                                * State is persistent across command calls and discussions with the user.\n
+                                * To inspect a particular line range of a file, e.g. lines 10-25, try 'sed -n 10,25p /path/to/the/file'.\n
+                                * Please avoid commands that may produce a very large amount of output.\n
+                                * Please run long lived commands in the background, e.g. 'sleep 10 &' or start a server in the background.""",
+                                "parameters": {
+                                    "type": "object",
+                                    "properties": {
+                                        "command": {
+                                            "type": "string",
+                                            "description": "The bash command to run.",
+                                        }
+                                    },
+                                    "required": ["command"],
+                                },
+                            },
+                        }
+                    )
+
+                if self.model.startswith("ollama/"):
+                    # Work around Ollama: route through its OpenAI-compatible
+                    # endpoint and disable streaming.
+                    stream = False
+                    actual_model = self.model.replace("ollama/", "openai/")
+                    if self.api_base is None:
+                        api_base = "http://localhost:11434/v1/"
+                    else:
+                        api_base = self.api_base
+                else:
+                    stream = True
+                    api_base = self.api_base
+                    actual_model = self.model
+
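+                # Prepend the system prompt to the conversation history.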
+                params = {
+                    "model": actual_model,
+                    "messages": [{"role": "system", "content": SYSTEM_PROMPT}]
+                    + self.messages,
+                    "stream": stream,
+                    "api_base": api_base,
+                    "temperature": self.temperature,
+                    "tools": tools,
+                }
+
+                raw_response = litellm.completion(**params)
+
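+                # Normalize a non-streaming response into a one-chunk list so
+                # the loop below can treat both cases the same way.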
+                if not stream:
+                    raw_response.choices[0].delta = raw_response.choices[0].message
+                    raw_response = [raw_response]
+
+                message = None
+                first_token = True
+                edit = ToolRenderer()  # renders streamed tool-call output
+
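+                # Fold the streamed deltas into a single assistant message.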
+                for chunk in raw_response:
+                    if first_token:
+                        self._spinner.stop()
+                        first_token = False
+
+                    if message is None:
+                        message = chunk.choices[0].delta
+
+                    if chunk.choices[0].delta.content:
+                        yield {"type": "chunk", "chunk": chunk.choices[0].delta.content}
+                        md.feed(chunk.choices[0].delta.content)
+                        await asyncio.sleep(0)
+
+                        if chunk.choices[0].delta != message:
+                            message.content += chunk.choices[0].delta.content
+
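+                    # Tool calls stream in fragments: a fresh id opens a new
+                    # call, later chunks append name and argument deltas.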
+                    if chunk.choices[0].delta.tool_calls:
+                        if chunk.choices[0].delta.tool_calls[0].id:
+                            if message.tool_calls is None or chunk.choices[
+                                0
+                            ].delta.tool_calls[0].id not in [
+                                t.id for t in message.tool_calls
+                            ]:
+                                edit.close()
+                                edit = ToolRenderer()
+                                if message.tool_calls is None:
+                                    message.tool_calls = []
+                                message.tool_calls.append(
+                                    chunk.choices[0].delta.tool_calls[0]
+                                )
+                            current_tool_call = [
+                                t
+                                for t in message.tool_calls
+                                if t.id == chunk.choices[0].delta.tool_calls[0].id
+                            ][0]
+
+                        if chunk.choices[0].delta.tool_calls[0].function.name:
+                            tool_name = (
+                                chunk.choices[0].delta.tool_calls[0].function.name
+                            )
+                            if edit.name is None:
+                                edit.name = tool_name
+                            if current_tool_call.function.name is None:
+                                current_tool_call.function.name = tool_name
+                        if chunk.choices[0].delta.tool_calls[0].function.arguments:
+                            arguments_delta = (
+                                chunk.choices[0].delta.tool_calls[0].function.arguments
+                            )
+                            edit.feed(arguments_delta)
+
+                            if chunk.choices[0].delta != message:
+                                current_tool_call.function.arguments += arguments_delta
+
+                    if chunk.choices[0].finish_reason:
+                        edit.close()
+                        edit = ToolRenderer()
+
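+                # Persist the assembled assistant message in the history.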
+                self.messages.append(message)
+
+                print()
+
+                if not message.tool_calls:
+                    yield {"type": "messages", "messages": self.messages}
+                    break
+
+                user_approval = self._ask_user_approval()
+
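+                # Run each approved call and append the result as a tool message.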
+                for tool_call in message.tool_calls:
+                    function_arguments = json.loads(tool_call.function.arguments)
+
+                    if user_approval == "y":
+                        result = await tool_collection.run(
+                            name=tool_call.function.name,
+                            tool_input=cast(dict[str, Any], function_arguments),
+                        )
+                    else:
+                        result = ToolResult(output="Tool execution cancelled by user")
+
+                    self.messages.append(
+                        {
+                            "role": "tool",
+                            "content": json.dumps(dataclasses.asdict(result)),
+                            "tool_call_id": tool_call.id,
+                        }
+                    )
 
     def _ask_user_approval(self) -> str:
         """Ask user for approval to run a tool"""
|
|