Skip to content

Commit 2aa6773

Browse files
committed
Use Pydantic AI properly
1 parent ab21706 commit 2aa6773

16 files changed

+280
-144
lines changed

examples/agentframework_basic.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,23 @@
1515
if API_HOST == "azure":
1616
async_credential = DefaultAzureCredential()
1717
token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
18-
client = OpenAIChatClient(base_url=f"{os.environ['AZURE_OPENAI_ENDPOINT']}/openai/v1/", api_key=token_provider, model_id=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"])
18+
client = OpenAIChatClient(
19+
base_url=f"{os.environ['AZURE_OPENAI_ENDPOINT']}/openai/v1/",
20+
api_key=token_provider,
21+
model_id=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
22+
)
1923
elif API_HOST == "github":
20-
client = OpenAIChatClient(base_url="https://models.github.ai/inference", api_key=os.environ["GITHUB_TOKEN"], model_id=os.getenv("GITHUB_MODEL", "openai/gpt-4o"))
24+
client = OpenAIChatClient(
25+
base_url="https://models.github.ai/inference",
26+
api_key=os.environ["GITHUB_TOKEN"],
27+
model_id=os.getenv("GITHUB_MODEL", "openai/gpt-4o"),
28+
)
2129
elif API_HOST == "ollama":
22-
client = OpenAIChatClient(base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"), api_key="none", model_id=os.environ.get("OLLAMA_MODEL", "llama3.1:latest"))
30+
client = OpenAIChatClient(
31+
base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"),
32+
api_key="none",
33+
model_id=os.environ.get("OLLAMA_MODEL", "llama3.1:latest"),
34+
)
2335
else:
2436
client = OpenAIChatClient(api_key=os.environ["OPENAI_API_KEY"], model_id=os.environ.get("OPENAI_MODEL", "gpt-4o"))
2537

examples/pydanticai_basic.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import asyncio
22
import os
33

4-
from azure.identity import DefaultAzureCredential
5-
from azure.identity.aio import get_bearer_token_provider
4+
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
65
from dotenv import load_dotenv
76
from openai import AsyncOpenAI
87
from pydantic_ai import Agent
@@ -13,19 +12,24 @@
1312
load_dotenv(override=True)
1413
API_HOST = os.getenv("API_HOST", "github")
1514

16-
if API_HOST == "github":
17-
client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
18-
model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
19-
elif API_HOST == "azure":
20-
token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
15+
async_credential = None
16+
if API_HOST == "azure":
17+
async_credential = DefaultAzureCredential()
18+
token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
2119
client = AsyncOpenAI(
2220
base_url=os.environ["AZURE_OPENAI_ENDPOINT"] + "/openai/v1",
2321
api_key=token_provider,
2422
)
2523
model = OpenAIChatModel(os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=client))
24+
elif API_HOST == "github":
25+
client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
26+
model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
2627
elif API_HOST == "ollama":
2728
client = AsyncOpenAI(base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"), api_key="none")
2829
model = OpenAIChatModel(os.environ["OLLAMA_MODEL"], provider=OpenAIProvider(openai_client=client))
30+
else:
31+
client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
32+
model = OpenAIChatModel(os.environ.get("OPENAI_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
2933

3034
agent: Agent[None, str] = Agent(
3135
model,
@@ -38,6 +42,9 @@ async def main():
3842
result = await agent.run("oh hey how are you?")
3943
print(result.output)
4044

45+
if async_credential:
46+
await async_credential.close()
47+
4148

4249
if __name__ == "__main__":
4350
asyncio.run(main())

examples/pydanticai_graph.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import os
55
from dataclasses import dataclass, field
66

7-
from azure.identity import DefaultAzureCredential
8-
from azure.identity.aio import get_bearer_token_provider
7+
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
98
from dotenv import load_dotenv
109
from groq import BaseModel
1110
from openai import AsyncOpenAI
@@ -24,19 +23,24 @@
2423
load_dotenv(override=True)
2524
API_HOST = os.getenv("API_HOST", "github")
2625

27-
if API_HOST == "github":
28-
client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
29-
model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
30-
elif API_HOST == "azure":
31-
token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
26+
async_credential = None
27+
if API_HOST == "azure":
28+
async_credential = DefaultAzureCredential()
29+
token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
3230
client = AsyncOpenAI(
3331
base_url=os.environ["AZURE_OPENAI_ENDPOINT"] + "/openai/v1",
3432
api_key=token_provider,
3533
)
3634
model = OpenAIChatModel(os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=client))
35+
elif API_HOST == "github":
36+
client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
37+
model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
3738
elif API_HOST == "ollama":
3839
client = AsyncOpenAI(base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"), api_key="none")
3940
model = OpenAIChatModel(os.environ["OLLAMA_MODEL"], provider=OpenAIProvider(openai_client=client))
41+
else:
42+
client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
43+
model = OpenAIChatModel(os.environ.get("OPENAI_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
4044

4145
"""
4246
Agent definitions
@@ -130,6 +134,9 @@ async def main():
130134
end = await question_graph.run(node, state=state)
131135
print("END:", end.output)
132136

137+
if async_credential:
138+
await async_credential.close()
139+
133140

134141
if __name__ == "__main__":
135142
asyncio.run(main())

examples/pydanticai_mcp_github.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
import logging
1919
import os
2020

21-
from azure.identity import DefaultAzureCredential
22-
from azure.identity.aio import get_bearer_token_provider
21+
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
2322
from dotenv import load_dotenv
2423
from openai import AsyncOpenAI
2524
from pydantic import BaseModel, Field
@@ -40,26 +39,24 @@
4039
load_dotenv(override=True)
4140
API_HOST = os.getenv("API_HOST", "github")
4241

43-
42+
async_credential = None
4443
if API_HOST == "azure":
45-
token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
44+
async_credential = DefaultAzureCredential()
45+
token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
4646
client = AsyncOpenAI(
4747
base_url=os.environ["AZURE_OPENAI_ENDPOINT"] + "/openai/v1",
4848
api_key=token_provider,
4949
)
50-
model = OpenAIChatModel(
51-
os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
52-
provider=OpenAIProvider(openai_client=client),
53-
)
50+
model = OpenAIChatModel(os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=client))
5451
elif API_HOST == "github":
5552
client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
56-
model = OpenAIChatModel(os.environ.get("GITHUB_MODEL", "gpt-4o-mini"), provider=OpenAIProvider(openai_client=client))
53+
model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
5754
elif API_HOST == "ollama":
58-
client = AsyncOpenAI(base_url=os.environ["OLLAMA_ENDPOINT"], api_key="none")
55+
client = AsyncOpenAI(base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"), api_key="none")
5956
model = OpenAIChatModel(os.environ["OLLAMA_MODEL"], provider=OpenAIProvider(openai_client=client))
6057
else:
61-
client = AsyncOpenAI()
62-
model = OpenAIChatModel(os.environ.get("OPENAI_MODEL", "gpt-4o-mini"), provider=OpenAIProvider(openai_client=client))
58+
client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
59+
model = OpenAIChatModel(os.environ.get("OPENAI_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
6360

6461

6562
class IssueProposal(BaseModel):
@@ -73,13 +70,18 @@ class IssueProposal(BaseModel):
7370

7471

7572
async def main():
76-
server = MCPServerStreamableHTTP(url="https://api.githubcopilot.com/mcp/", headers={"Authorization": f"Bearer {os.getenv('GITHUB_TOKEN', '')}"})
73+
server = MCPServerStreamableHTTP(
74+
url="https://api.githubcopilot.com/mcp/", headers={"Authorization": f"Bearer {os.getenv('GITHUB_TOKEN', '')}"}
75+
)
7776
desired_tool_names = ("list_issues", "search_code", "search_issues", "search_pull_requests")
7877
filtered_tools = server.filtered(lambda ctx, tool_def: tool_def.name in desired_tool_names)
7978

8079
agent: Agent[None, IssueProposal] = Agent(
8180
model,
82-
system_prompt=("You are an issue triage assistant. Use the provided tools to find an issue that can be closed " "and produce an IssueProposal."),
81+
system_prompt=(
82+
"You are an issue triage assistant. Use the provided tools to find an issue that can be closed "
83+
"and produce an IssueProposal."
84+
),
8385
output_type=IssueProposal,
8486
toolsets=[filtered_tools],
8587
)
@@ -96,6 +98,9 @@ async def main():
9698

9799
print(agent_run.result.output)
98100

101+
if async_credential:
102+
await async_credential.close()
103+
99104

100105
if __name__ == "__main__":
101106
logger.setLevel(logging.INFO)

examples/pydanticai_mcp_http.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
import logging
1010
import os
1111

12-
from azure.identity import DefaultAzureCredential
13-
from azure.identity.aio import get_bearer_token_provider
12+
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
1413
from dotenv import load_dotenv
1514
from openai import AsyncOpenAI
1615
from pydantic_ai import Agent
@@ -22,19 +21,24 @@
2221
load_dotenv(override=True)
2322
API_HOST = os.getenv("API_HOST", "github")
2423

25-
if API_HOST == "github":
26-
client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
27-
model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
28-
elif API_HOST == "azure":
29-
token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
24+
async_credential = None
25+
if API_HOST == "azure":
26+
async_credential = DefaultAzureCredential()
27+
token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
3028
client = AsyncOpenAI(
3129
base_url=os.environ["AZURE_OPENAI_ENDPOINT"] + "/openai/v1",
3230
api_key=token_provider,
3331
)
3432
model = OpenAIChatModel(os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=client))
33+
elif API_HOST == "github":
34+
client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
35+
model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
3536
elif API_HOST == "ollama":
3637
client = AsyncOpenAI(base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"), api_key="none")
3738
model = OpenAIChatModel(os.environ["OLLAMA_MODEL"], provider=OpenAIProvider(openai_client=client))
39+
else:
40+
client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
41+
model = OpenAIChatModel(os.environ.get("OPENAI_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
3842

3943
server = MCPServerStreamableHTTP(url="http://localhost:8000/mcp")
4044

@@ -47,9 +51,14 @@
4751

4852

4953
async def main():
50-
result = await agent.run("Find me a hotel in San Francisco for 2 nights starting from 2024-01-01. I need a hotel with free WiFi and a pool.")
54+
result = await agent.run(
55+
"Find me a hotel in San Francisco for 2 nights starting from 2024-01-01. I need free WiFi and a pool."
56+
)
5157
print(result.output)
5258

59+
if async_credential:
60+
await async_credential.close()
61+
5362

5463
if __name__ == "__main__":
5564
logging.basicConfig(level=logging.WARNING)

examples/pydanticai_multiagent.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import os
33
from typing import Literal
44

5-
from azure.identity import DefaultAzureCredential
6-
from azure.identity.aio import get_bearer_token_provider
5+
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
76
from dotenv import load_dotenv
87
from openai import AsyncOpenAI
98
from pydantic import BaseModel, Field
@@ -17,19 +16,24 @@
1716
load_dotenv(override=True)
1817
API_HOST = os.getenv("API_HOST", "github")
1918

20-
if API_HOST == "github":
21-
client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
22-
model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
23-
elif API_HOST == "azure":
24-
token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
19+
async_credential = None
20+
if API_HOST == "azure":
21+
async_credential = DefaultAzureCredential()
22+
token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
2523
client = AsyncOpenAI(
2624
base_url=os.environ["AZURE_OPENAI_ENDPOINT"] + "/openai/v1",
2725
api_key=token_provider,
2826
)
2927
model = OpenAIChatModel(os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=client))
28+
elif API_HOST == "github":
29+
client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
30+
model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
3031
elif API_HOST == "ollama":
3132
client = AsyncOpenAI(base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"), api_key="none")
3233
model = OpenAIChatModel(os.environ["OLLAMA_MODEL"], provider=OpenAIProvider(openai_client=client))
34+
else:
35+
client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
36+
model = OpenAIChatModel(os.environ.get("OPENAI_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
3337

3438

3539
class Flight(BaseModel):
@@ -76,7 +80,12 @@ class Seat(BaseModel):
7680
seat_preference_agent = Agent(
7781
model,
7882
output_type=Seat | Failed,
79-
system_prompt=("Extract the user's seat preference. " "Seats A and F are window seats. " "Row 1 is the front row and has extra leg room. " "Rows 14, and 20 also have extra leg room. "),
83+
system_prompt=(
84+
"Extract the user's seat preference. "
85+
"Seats A and F are window seats. "
86+
"Row 1 is the front row and has extra leg room. "
87+
"Rows 14, and 20 also have extra leg room. "
88+
),
8089
)
8190

8291

@@ -100,6 +109,9 @@ async def main():
100109
seat_preference = await find_seat()
101110
print(f"Seat preference: {seat_preference}")
102111

112+
if async_credential:
113+
await async_credential.close()
114+
103115

104116
if __name__ == "__main__":
105117
asyncio.run(main())

0 commit comments

Comments
 (0)