Skip to content

Commit 3ceed3e

Browse files
Added agent demo code
1 parent ebca0ed commit 3ceed3e

File tree

5 files changed

+1294
-15
lines changed

5 files changed

+1294
-15
lines changed

example_package/ai_agent_demo.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
import os
2+
from datetime import UTC, datetime, timedelta
3+
from enum import Enum, StrEnum
4+
5+
import httpx
6+
import logfire
7+
import ollama
8+
from pydantic import BaseModel, field_validator
9+
from pydantic_ai import Agent
10+
11+
# ------------------------ Observability ------------------------------
12+
logfire.configure(send_to_logfire=False)
13+
logfire.instrument_pydantic_ai()
14+
logfire.instrument_httpx()
15+
16+
# ------------------------- Agent -------------------------------------
17+
os.environ["OLLAMA_BASE_URL"] = "http://localhost:11434/v1"
18+
19+
MODEL = "gpt-oss:20b"
20+
21+
try:
22+
ollama.pull(MODEL)
23+
except Exception as exc:
24+
raise RuntimeError("ollama not installed on your system.") from exc
25+
26+
27+
agent = Agent(
28+
model=f"ollama:{MODEL}",
29+
instructions=("Be concise."),
30+
)
31+
32+
33+
# ----------------------- Tools -------------------------------------
34+
class Market(Enum):
35+
FTSE_100 = "ftse_100"
36+
SNP_500 = "s&p_500"
37+
DAX = "dax"
38+
HANG_SENG = "hang_seng"
39+
STRAITS_TIMES = "strait_times"
40+
NIKKEI = "nikkei"
41+
42+
@property
43+
def symbol(self) -> str:
44+
"""Returns the symbol for the market."""
45+
match self:
46+
case Market.FTSE_100:
47+
return "^FTSE"
48+
case Market.SNP_500:
49+
return "^GSPC"
50+
case Market.DAX:
51+
return "^GDAXI"
52+
case Market.HANG_SENG:
53+
return "^HSI"
54+
case Market.STRAITS_TIMES:
55+
return "^STI"
56+
case Market.NIKKEI:
57+
return "^N225"
58+
59+
raise ValueError("Unknown Market")
60+
61+
62+
class City(StrEnum):
63+
NEW_YORK = "New York"
64+
LONDON = "London"
65+
TOKYO = "Tokyo"
66+
SINGAPORE = "Singapore"
67+
FRANKFURT = "Frankfurt"
68+
HONG_KONG = "Hong Kong"
69+
70+
71+
class MarketMetaData(BaseModel):
72+
currency: str
73+
symbol: str
74+
exchangeName: str
75+
fullExchangeName: str
76+
instrumentType: str
77+
timezone: str
78+
regularMarketPrice: float
79+
fiftyTwoWeekHigh: float
80+
fiftyTwoWeekLow: float
81+
regularMarketDayHigh: float
82+
regularMarketDayLow: float
83+
longName: str
84+
85+
86+
class MarketIndicators(BaseModel):
87+
low: list[float]
88+
close: list[float]
89+
volume: list[int]
90+
close: list[float]
91+
92+
93+
class AdjIndicators(BaseModel):
94+
adjclose: list[float]
95+
96+
97+
class Indicators(BaseModel):
98+
quote: list[MarketIndicators]
99+
adjclose: list[AdjIndicators]
100+
101+
102+
class StockMarketData(BaseModel):
103+
meta: MarketMetaData
104+
timestamp: list[str]
105+
indicators: Indicators
106+
107+
@field_validator("timestamp", mode="before")
108+
@classmethod
109+
def _validate_timestamps(cls, timestamp: list[int]) -> list[str]:
110+
dt = [datetime.fromtimestamp(ts, tz=UTC) for ts in timestamp]
111+
return [dt.strftime("%Y-%m-%d %H:%M:%S") for dt in dt]
112+
113+
114+
class DailyUnits(BaseModel):
115+
temperature_2m_max: str
116+
temperature_2m_min: str
117+
uv_index_max: str
118+
sunshine_duration: str
119+
daylight_duration: str
120+
rain_sum: str
121+
showers_sum: str
122+
snowfall_sum: str
123+
wind_speed_10m_max: str
124+
wind_gusts_10m_max: str
125+
126+
127+
class DailyWeatherData(BaseModel):
128+
time: list[str]
129+
temperature_2m_max: list[float]
130+
temperature_2m_min: list[float]
131+
uv_index_max: list[float]
132+
sunshine_duration: list[float]
133+
daylight_duration: list[float]
134+
rain_sum: list[float]
135+
showers_sum: list[float]
136+
wind_speed_10m_max: list[float]
137+
wind_gusts_10m_max: list[float]
138+
139+
140+
class WeatherData(BaseModel):
141+
latitude: float
142+
longitude: float
143+
elevation: float
144+
timezone: str
145+
timezone_abbreviation: str
146+
elevation: float
147+
daily: DailyWeatherData
148+
149+
150+
# ------------------------ Tools ---------------------------------------------
151+
@agent.tool_plain
152+
async def get_market_data(market: Market, timerange: int) -> StockMarketData:
153+
"""Return market data for a given market.
154+
155+
Args:
156+
market: The market to query.
157+
timerange: The number of days to look back.
158+
interval: Time interval
159+
"""
160+
headers = {
161+
"User-Agent": (
162+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
163+
"AppleWebKit/537.36 (KHTML, like Gecko) "
164+
"Chrome/127.0.0.0 Safari/537.36"
165+
),
166+
"Accept": "application/json, text/javascript, */*; q=0.01",
167+
"Accept-Language": "en-US,en;q=0.9",
168+
"Referer": "https://finance.yahoo.com/",
169+
}
170+
171+
end = datetime.now(tz=UTC)
172+
start = end - timedelta(days=timerange)
173+
start_ts, end_ts = int(start.timestamp()), int(end.timestamp())
174+
query_url = f"https://query1.finance.yahoo.com/v8/finance/chart/{market.symbol}?period1={start_ts}&period2={end_ts}&interval=1d"
175+
async with httpx.AsyncClient(headers=headers, timeout=10.0) as client:
176+
r = await client.get(query_url)
177+
r.raise_for_status()
178+
data = r.json()
179+
180+
return StockMarketData.model_validate(data["chart"]["result"][0])
181+
182+
183+
def _geocode_city(name: str) -> tuple[float, float]:
184+
url = "https://geocoding-api.open-meteo.com/v1/search"
185+
params = {"name": name, "count": 1, "language": "en", "format": "json"}
186+
with httpx.Client(timeout=10) as x:
187+
r = x.get(url, params=params)
188+
r.raise_for_status()
189+
j = r.json()
190+
if not j.get("results"):
191+
raise ValueError(f"city not found: {name}")
192+
lat = float(j["results"][0]["latitude"])
193+
lon = float(j["results"][0]["longitude"])
194+
return lat, lon
195+
196+
197+
@agent.tool_plain
198+
async def city_weather(city: City, timerange: int) -> WeatherData:
199+
"""Return recent weather for a city."""
200+
201+
latitude, longitude = _geocode_city(city)
202+
203+
url = "https://api.open-meteo.com/v1/forecast"
204+
params = {
205+
"latitude": latitude,
206+
"longitude": longitude,
207+
"daily": [
208+
"temperature_2m_max",
209+
"temperature_2m_min",
210+
"sunshine_duration",
211+
"rain_sum",
212+
"showers_sum",
213+
"wind_speed_10m_max",
214+
],
215+
"past_days": timerange,
216+
"forecast_days": 0,
217+
}
218+
with httpx.Client(timeout=10) as x:
219+
r = x.get(url, params=params)
220+
r.raise_for_status()
221+
j = r.json()
222+
223+
return WeatherData.model_validate(j)
224+
225+
226+
# ------------------------ Runner --------------------------------------------
227+
async def run_agent(prompt: str) -> str:
228+
"""Run the agent with a given prompt."""
229+
result = await agent.run(prompt)
230+
return result.output

example_package/main.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77

88
from rich import print as pprint
99

10-
from example_package.json_loading_demo import (
10+
from example_package.ai_agent_demo import run_agent
11+
from example_package.json_loading_benchmark import (
1112
DATASET_CONFIGS,
1213
JsonLoader,
1314
TypedDecoder,
1415
benchmark_json_loading,
1516
benchmark_typed_decoding,
1617
)
17-
from example_package.mandelbrot import ExecutionMode, time_mandelbrot
18+
from example_package.mandelbrot_benchmark import ExecutionMode, time_mandelbrot
1819

1920

2021
def hello_world() -> None:
@@ -71,5 +72,12 @@ def mandelbrot_performance_demo() -> None:
7172
pprint(f"Elapsed time ({mode.value}): {perf_counter() - t0:.3f} seconds")
7273

7374

75+
def ai_agent_demo() -> None:
76+
"""Demos a Pydantic AI agent workflow with some dummy tools."""
77+
prompt = "Using your tools, analyze if cities with bad weather recently experienced worse stock market performance."
78+
response = asyncio.run(run_agent(prompt))
79+
print(response)
80+
81+
7482
if __name__ == "__main__":
7583
pprint("Start your projects main entrypoint here")

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ requires-python = ">=3.13,<3.14"
1111
dependencies = [
1212
"aiofiles>=25",
1313
"httpx>=0.28",
14-
"logfire>=4",
14+
"logfire[httpx]>=4",
1515
"msgspec>=0.19",
1616
"numba>=0.62",
1717
"numpy>=2",
18+
"ollama>=0.6",
1819
"orjson>=3",
1920
"pydantic>=2",
21+
"pydantic-ai>=1",
2022
"rich>=14",
2123
]
2224

@@ -25,6 +27,7 @@ hello-world = "example_package.main:hello_world"
2527
mandelbrot-demo = "example_package.main:mandelbrot_performance_demo"
2628
json-loading-demo = "example_package.main:json_loading_performance_demo"
2729
typed-decoding-demo = "example_package.main:typed_decoding_demo"
30+
ai-agent-demo = "example_package.main:ai_agent_demo"
2831

2932
[dependency-groups]
3033
dev = [

ruff.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ lint.ignore = [
5858
"D213", # Allow multi-line docstring summary to start at the second line
5959
"N806", # Allow uppercased variable names
6060
"N812", # Allow mixed-case variable names
61+
"N815", # Allow camelCase attributes
6162
"S101", # Allow use of assert statements
6263
"SIM108", # Allow if/else blocks (instead of requiring ternary operations)
6364

0 commit comments

Comments
 (0)