Skip to content

Commit 9295ca9

Browse files
author
ochafik
committed
tool-call: fix agent type lints
1 parent 1e5c0e7 commit 9295ca9

File tree

4 files changed

+24
-25
lines changed

4 files changed

+24
-25
lines changed

examples/tool-call/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@
3030
uv run examples/tool-call/agent.py \
3131
--tool-endpoint http://localhost:8088 \
3232
--goal "What is the sum of 2535 squared and 32222000403 then multiplied by one and a half. What's a third of the result?"
33-
```
33+
```

examples/tool-call/agent.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
# ///
1212
import json
1313
import openai
14+
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
1415
from pydantic import BaseModel
1516
import requests
1617
import sys
1718
import typer
18-
from typing import Annotated, List, Optional
19-
import urllib
19+
from typing import Annotated, Optional
20+
import urllib.parse
2021

2122

2223
class OpenAPIMethod:
@@ -94,24 +95,24 @@ def __call__(self, **kwargs):
9495
def main(
9596
goal: Annotated[str, typer.Option()],
9697
api_key: Optional[str] = None,
97-
tool_endpoint: Optional[List[str]] = None,
98+
tool_endpoint: Optional[list[str]] = None,
9899
format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
99100
max_iterations: Optional[int] = 10,
100101
parallel_calls: Optional[bool] = False,
101102
verbose: bool = False,
102103
# endpoint: Optional[str] = None,
103104
endpoint: str = "http://localhost:8080/v1/",
104105
):
105-
106+
106107
openai.api_key = api_key
107108
openai.base_url = endpoint
108-
109+
109110
tool_map = {}
110111
tools = []
111-
112+
112113
for url in (tool_endpoint or []):
113114
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'
114-
115+
115116
catalog_url = f'{url}/openapi.json'
116117
catalog_response = requests.get(catalog_url)
117118
catalog_response.raise_for_status()
@@ -131,19 +132,19 @@ def main(
131132
)
132133
)
133134
)
134-
135+
135136
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
136137

137-
messages = [
138-
dict(
138+
messages: list[ChatCompletionMessageParam] = [
139+
ChatCompletionUserMessageParam(
139140
role="user",
140141
content=goal,
141142
)
142143
]
143144

144145
i = 0
145146
while (max_iterations is None or i < max_iterations):
146-
147+
147148
response = openai.chat.completions.create(
148149
model="gpt-4o",
149150
messages=messages,
@@ -152,13 +153,14 @@ def main(
152153

153154
if verbose:
154155
sys.stderr.write(f'# RESPONSE: {response}\n')
155-
156+
156157
assert len(response.choices) == 1
157158
choice = response.choices[0]
158159

159160
content = choice.message.content
160161
if choice.finish_reason == "tool_calls":
161-
messages.append(choice.message)
162+
messages.append(choice.message) # type: ignore
163+
assert choice.message.tool_calls
162164
for tool_call in choice.message.tool_calls:
163165
if content:
164166
print(f'💭 {content}')
@@ -169,11 +171,11 @@ def main(
169171
sys.stdout.flush()
170172
tool_result = tool_map[tool_call.function.name](**args)
171173
sys.stdout.write(f" → {tool_result}\n")
172-
messages.append(dict(
174+
messages.append(ChatCompletionToolMessageParam(
173175
tool_call_id=tool_call.id,
174176
role="tool",
175-
name=tool_call.function.name,
176-
content=f'{tool_result}',
177+
# name=tool_call.function.name,
178+
content=json.dumps(tool_result),
177179
# content=f'{pretty_call} = {tool_result}',
178180
))
179181
else:

examples/tool-call/fastify.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
5959
continue
6060

6161
vt = type(v)
62-
if vt.__module__ == 'langchain_core.tools' and vt.__name__.endswith('Tool') and hasattr(v, 'func') and callable(v.func):
63-
v = v.func
62+
if vt.__module__ == 'langchain_core.tools' and vt.__name__.endswith('Tool') and hasattr(v, 'func') and callable(func := getattr(v, 'func')):
63+
v = func
6464

6565
print(f'INFO: Binding /{k}')
6666
try:
@@ -73,4 +73,4 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
7373

7474

7575
if __name__ == '__main__':
76-
typer.run(main)
76+
typer.run(main)

examples/tool-call/tools.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
from datetime import date
21
import datetime
32
import json
43
from pydantic import BaseModel
5-
import subprocess
64
import sys
75
import time
8-
import typer
9-
from typing import Union, Optional, Dict
106
import types
7+
from typing import Union, Optional, Dict
118

129

1310
class Duration(BaseModel):
@@ -58,7 +55,7 @@ def wait_for_duration(duration: Duration) -> None:
5855
time.sleep(duration.get_total_seconds)
5956

6057
@staticmethod
61-
def wait_for_date(target_date: date) -> None:
58+
def wait_for_date(target_date: datetime.date) -> None:
6259
f'''
6360
Wait until a specific date is reached before continuing.
6461
Today's date is {datetime.date.today()}

0 commit comments

Comments
 (0)