|
| 1 | +import asyncio |
| 2 | +import logging |
| 3 | +import os |
| 4 | +import random |
| 5 | +from datetime import datetime |
| 6 | + |
| 7 | +import azure.identity |
| 8 | +from dotenv import load_dotenv |
| 9 | +from openai import AsyncAzureOpenAI, AsyncOpenAI |
| 10 | +from pydantic_ai import Agent |
| 11 | +from pydantic_ai.models.openai import OpenAIModel |
| 12 | +from pydantic_ai.providers.openai import OpenAIProvider |
| 13 | +from rich.logging import RichHandler |
| 14 | + |
| 15 | +# Setup logging with rich |
| 16 | +logging.basicConfig(level=logging.WARNING, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]) |
| 17 | +logger = logging.getLogger("weekend_planner") |
| 18 | + |
| 19 | + |
| 20 | +# Setup the OpenAI client to use either Azure OpenAI or GitHub Models |
| 21 | +load_dotenv(override=True) |
| 22 | +API_HOST = os.getenv("API_HOST", "github") |
| 23 | + |
| 24 | +if API_HOST == "github": |
| 25 | + client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com") |
| 26 | + model = OpenAIModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client)) |
| 27 | +elif API_HOST == "azure": |
| 28 | + token_provider = azure.identity.get_bearer_token_provider(azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default") |
| 29 | + client = AsyncAzureOpenAI( |
| 30 | + api_version=os.environ["AZURE_OPENAI_VERSION"], |
| 31 | + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"], |
| 32 | + azure_ad_token_provider=token_provider, |
| 33 | + ) |
| 34 | + model = OpenAIModel(os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=client)) |
| 35 | + |
| 36 | + |
| 37 | +def get_weather(city: str) -> str: |
| 38 | + logger.info(f"Getting weather for {city}") |
| 39 | + if random.random() < 0.05: |
| 40 | + return { |
| 41 | + "city": city, |
| 42 | + "temperature": 72, |
| 43 | + "description": "Sunny", |
| 44 | + } |
| 45 | + else: |
| 46 | + return { |
| 47 | + "city": city, |
| 48 | + "temperature": 60, |
| 49 | + "description": "Rainy", |
| 50 | + } |
| 51 | + |
| 52 | + |
| 53 | +def get_activities(city: str, date: str) -> list: |
| 54 | + logger.info(f"Getting activities for {city} on {date}") |
| 55 | + return [ |
| 56 | + {"name": "Hiking", "location": city}, |
| 57 | + {"name": "Beach", "location": city}, |
| 58 | + {"name": "Museum", "location": city}, |
| 59 | + ] |
| 60 | + |
| 61 | +def get_current_date() -> str: |
| 62 | + logger.info("Getting current date") |
| 63 | + return datetime.now().strftime("%Y-%m-%d") |
| 64 | + |
| 65 | + |
| 66 | +agent = Agent( |
| 67 | + model, |
| 68 | + system_prompt="You help users plan their weekends and choose the best activities for the given weather. If an activity would be unpleasant in the weather, don't suggest it. Include the date of the weekend in your response.", |
| 69 | + tools=[get_weather, get_activities, get_current_date], |
| 70 | +) |
| 71 | + |
| 72 | +async def main(): |
| 73 | + result = await agent.run("what can I do for funzies this weekend in Seattle?") |
| 74 | + print(result.output) |
| 75 | + |
| 76 | + |
| 77 | +if __name__ == "__main__": |
| 78 | + logger.setLevel(logging.INFO) |
| 79 | + asyncio.run(main()) |
0 commit comments