Skip to content

Commit 4513c8f

Browse files
committed
Fix pre-commit issues
1 parent ab2ff90 commit 4513c8f

File tree

3 files changed

+162
-154
lines changed

3 files changed

+162
-154
lines changed

examples/pydantic_ai_examples/flight_booking_grok.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
"""
55

66
import datetime
7+
import os
78
from dataclasses import dataclass
89
from typing import Literal
9-
import os
10+
1011
import logfire
1112
from pydantic import BaseModel, Field
1213
from rich.prompt import Prompt
@@ -22,22 +23,22 @@
2223
logfire.instrument_httpx()
2324

2425
# Configure for xAI API
25-
xai_api_key = os.getenv("XAI_API_KEY")
26+
xai_api_key = os.getenv('XAI_API_KEY')
2627
if not xai_api_key:
27-
raise ValueError("XAI_API_KEY environment variable is required")
28+
raise ValueError('XAI_API_KEY environment variable is required')
2829

2930

3031
# Create the model using the new GrokModelpwd
31-
model = GrokModel("grok-4-fast-non-reasoning", api_key=xai_api_key)
32+
model = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key)
3233

3334

3435
class FlightDetails(BaseModel):
3536
"""Details of the most suitable flight."""
3637

3738
flight_number: str
3839
price: int
39-
origin: str = Field(description="Three-letter airport code")
40-
destination: str = Field(description="Three-letter airport code")
40+
origin: str = Field(description='Three-letter airport code')
41+
destination: str = Field(description='Three-letter airport code')
4142
date: datetime.date
4243

4344

@@ -58,15 +59,17 @@ class Deps:
5859
model=model,
5960
output_type=FlightDetails | NoFlightFound, # type: ignore
6061
retries=4,
61-
system_prompt=("Your job is to find the cheapest flight for the user on the given date. "),
62+
system_prompt=(
63+
'Your job is to find the cheapest flight for the user on the given date. '
64+
),
6265
)
6366

6467

6568
# This agent is responsible for extracting flight details from web page text.
6669
extraction_agent = Agent(
6770
model=model,
6871
output_type=list[FlightDetails],
69-
system_prompt="Extract all the flight details from the given text.",
72+
system_prompt='Extract all the flight details from the given text.',
7073
)
7174

7275

@@ -75,7 +78,7 @@ async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]:
7578
"""Get details of all flights."""
7679
# we pass the usage to the search agent so requests within this agent are counted
7780
result = await extraction_agent.run(ctx.deps.web_page_text, usage=ctx.usage)
78-
logfire.info("found {flight_count} flights", flight_count=len(result.output))
81+
logfire.info('found {flight_count} flights', flight_count=len(result.output))
7982
return result.output
8083

8184

@@ -89,23 +92,25 @@ async def validate_output(
8992

9093
errors: list[str] = []
9194
if output.origin != ctx.deps.req_origin:
92-
errors.append(f"Flight should have origin {ctx.deps.req_origin}, not {output.origin}")
95+
errors.append(
96+
f'Flight should have origin {ctx.deps.req_origin}, not {output.origin}'
97+
)
9398
if output.destination != ctx.deps.req_destination:
9499
errors.append(
95-
f"Flight should have destination {ctx.deps.req_destination}, not {output.destination}"
100+
f'Flight should have destination {ctx.deps.req_destination}, not {output.destination}'
96101
)
97102
if output.date != ctx.deps.req_date:
98-
errors.append(f"Flight should be on {ctx.deps.req_date}, not {output.date}")
103+
errors.append(f'Flight should be on {ctx.deps.req_date}, not {output.date}')
99104

100105
if errors:
101-
raise ModelRetry("\n".join(errors))
106+
raise ModelRetry('\n'.join(errors))
102107
else:
103108
return output
104109

105110

106111
class SeatPreference(BaseModel):
107112
row: int = Field(ge=1, le=30)
108-
seat: Literal["A", "B", "C", "D", "E", "F"]
113+
seat: Literal['A', 'B', 'C', 'D', 'E', 'F']
109114

110115

111116
class Failed(BaseModel):
@@ -118,9 +123,9 @@ class Failed(BaseModel):
118123
output_type=SeatPreference | Failed,
119124
system_prompt=(
120125
"Extract the user's seat preference. "
121-
"Seats A and F are window seats. "
122-
"Row 1 is the front row and has extra leg room. "
123-
"Rows 14, and 20 also have extra leg room. "
126+
'Seats A and F are window seats. '
127+
'Row 1 is the front row and has extra leg room. '
128+
'Rows 14, and 20 also have extra leg room. '
124129
),
125130
)
126131

@@ -184,46 +189,46 @@ class Failed(BaseModel):
184189
async def main():
185190
deps = Deps(
186191
web_page_text=flights_web_page,
187-
req_origin="SFO",
188-
req_destination="ANC",
192+
req_origin='SFO',
193+
req_destination='ANC',
189194
req_date=datetime.date(2025, 1, 10),
190195
)
191196
message_history: list[ModelMessage] | None = None
192197
usage: RunUsage = RunUsage()
193198
# run the agent until a satisfactory flight is found
194199
while True:
195200
result = await search_agent.run(
196-
f"Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}",
201+
f'Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}',
197202
deps=deps,
198203
usage=usage,
199204
message_history=message_history,
200205
usage_limits=usage_limits,
201206
)
202207
if isinstance(result.output, NoFlightFound):
203-
print("No flight found")
208+
print('No flight found')
204209
break
205210
else:
206211
flight = result.output
207-
print(f"Flight found: {flight}")
212+
print(f'Flight found: {flight}')
208213
answer = Prompt.ask(
209-
"Do you want to buy this flight, or keep searching? (buy/*search)",
210-
choices=["buy", "search", ""],
214+
'Do you want to buy this flight, or keep searching? (buy/*search)',
215+
choices=['buy', 'search', ''],
211216
show_choices=False,
212217
)
213-
if answer == "buy":
218+
if answer == 'buy':
214219
seat = await find_seat(usage)
215220
await buy_tickets(flight, seat)
216221
break
217222
else:
218223
message_history = result.all_messages(
219-
output_tool_return_content="Please suggest another flight"
224+
output_tool_return_content='Please suggest another flight'
220225
)
221226

222227

223228
async def find_seat(usage: RunUsage) -> SeatPreference:
224229
message_history: list[ModelMessage] | None = None
225230
while True:
226-
answer = Prompt.ask("What seat would you like?")
231+
answer = Prompt.ask('What seat would you like?')
227232

228233
result = await seat_preference_agent.run(
229234
answer,
@@ -234,15 +239,15 @@ async def find_seat(usage: RunUsage) -> SeatPreference:
234239
if isinstance(result.output, SeatPreference):
235240
return result.output
236241
else:
237-
print("Could not understand seat preference. Please try again.")
242+
print('Could not understand seat preference. Please try again.')
238243
message_history = result.all_messages()
239244

240245

241246
async def buy_tickets(flight_details: FlightDetails, seat: SeatPreference):
242-
print(f"Purchasing flight {flight_details=!r} {seat=!r}...")
247+
print(f'Purchasing flight {flight_details=!r} {seat=!r}...')
243248

244249

245-
if __name__ == "__main__":
250+
if __name__ == '__main__':
246251
import asyncio
247252

248253
asyncio.run(main())

0 commit comments

Comments
 (0)