44"""
55
66import datetime
7+ import os
78from dataclasses import dataclass
89from typing import Literal
9- import os
10+
1011import logfire
1112from pydantic import BaseModel , Field
1213from rich .prompt import Prompt
2223logfire .instrument_httpx ()
2324
2425# Configure for xAI API
25- xai_api_key = os .getenv (" XAI_API_KEY" )
26+ xai_api_key = os .getenv (' XAI_API_KEY' )
2627if not xai_api_key :
27- raise ValueError (" XAI_API_KEY environment variable is required" )
28+ raise ValueError (' XAI_API_KEY environment variable is required' )
2829
2930
3031# Create the model using the new GrokModelpwd
31- model = GrokModel (" grok-4-fast-non-reasoning" , api_key = xai_api_key )
32+ model = GrokModel (' grok-4-fast-non-reasoning' , api_key = xai_api_key )
3233
3334
3435class FlightDetails (BaseModel ):
3536 """Details of the most suitable flight."""
3637
3738 flight_number : str
3839 price : int
39- origin : str = Field (description = " Three-letter airport code" )
40- destination : str = Field (description = " Three-letter airport code" )
40+ origin : str = Field (description = ' Three-letter airport code' )
41+ destination : str = Field (description = ' Three-letter airport code' )
4142 date : datetime .date
4243
4344
@@ -58,15 +59,17 @@ class Deps:
5859 model = model ,
5960 output_type = FlightDetails | NoFlightFound , # type: ignore
6061 retries = 4 ,
61- system_prompt = ("Your job is to find the cheapest flight for the user on the given date. " ),
62+ system_prompt = (
63+ 'Your job is to find the cheapest flight for the user on the given date. '
64+ ),
6265)
6366
6467
6568# This agent is responsible for extracting flight details from web page text.
6669extraction_agent = Agent (
6770 model = model ,
6871 output_type = list [FlightDetails ],
69- system_prompt = " Extract all the flight details from the given text." ,
72+ system_prompt = ' Extract all the flight details from the given text.' ,
7073)
7174
7275
@@ -75,7 +78,7 @@ async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]:
7578 """Get details of all flights."""
7679 # we pass the usage to the search agent so requests within this agent are counted
7780 result = await extraction_agent .run (ctx .deps .web_page_text , usage = ctx .usage )
78- logfire .info (" found {flight_count} flights" , flight_count = len (result .output ))
81+ logfire .info (' found {flight_count} flights' , flight_count = len (result .output ))
7982 return result .output
8083
8184
@@ -89,23 +92,25 @@ async def validate_output(
8992
9093 errors : list [str ] = []
9194 if output .origin != ctx .deps .req_origin :
92- errors .append (f"Flight should have origin { ctx .deps .req_origin } , not { output .origin } " )
95+ errors .append (
96+ f'Flight should have origin { ctx .deps .req_origin } , not { output .origin } '
97+ )
9398 if output .destination != ctx .deps .req_destination :
9499 errors .append (
95- f" Flight should have destination { ctx .deps .req_destination } , not { output .destination } "
100+ f' Flight should have destination { ctx .deps .req_destination } , not { output .destination } '
96101 )
97102 if output .date != ctx .deps .req_date :
98- errors .append (f" Flight should be on { ctx .deps .req_date } , not { output .date } " )
103+ errors .append (f' Flight should be on { ctx .deps .req_date } , not { output .date } ' )
99104
100105 if errors :
101- raise ModelRetry (" \n " .join (errors ))
106+ raise ModelRetry (' \n ' .join (errors ))
102107 else :
103108 return output
104109
105110
106111class SeatPreference (BaseModel ):
107112 row : int = Field (ge = 1 , le = 30 )
108- seat : Literal ["A" , "B" , "C" , "D" , "E" , "F" ]
113+ seat : Literal ['A' , 'B' , 'C' , 'D' , 'E' , 'F' ]
109114
110115
111116class Failed (BaseModel ):
@@ -118,9 +123,9 @@ class Failed(BaseModel):
118123 output_type = SeatPreference | Failed ,
119124 system_prompt = (
120125 "Extract the user's seat preference. "
121- " Seats A and F are window seats. "
122- " Row 1 is the front row and has extra leg room. "
123- " Rows 14, and 20 also have extra leg room. "
126+ ' Seats A and F are window seats. '
127+ ' Row 1 is the front row and has extra leg room. '
128+ ' Rows 14, and 20 also have extra leg room. '
124129 ),
125130)
126131
@@ -184,46 +189,46 @@ class Failed(BaseModel):
184189async def main ():
185190 deps = Deps (
186191 web_page_text = flights_web_page ,
187- req_origin = " SFO" ,
188- req_destination = " ANC" ,
192+ req_origin = ' SFO' ,
193+ req_destination = ' ANC' ,
189194 req_date = datetime .date (2025 , 1 , 10 ),
190195 )
191196 message_history : list [ModelMessage ] | None = None
192197 usage : RunUsage = RunUsage ()
193198 # run the agent until a satisfactory flight is found
194199 while True :
195200 result = await search_agent .run (
196- f" Find me a flight from { deps .req_origin } to { deps .req_destination } on { deps .req_date } " ,
201+ f' Find me a flight from { deps .req_origin } to { deps .req_destination } on { deps .req_date } ' ,
197202 deps = deps ,
198203 usage = usage ,
199204 message_history = message_history ,
200205 usage_limits = usage_limits ,
201206 )
202207 if isinstance (result .output , NoFlightFound ):
203- print (" No flight found" )
208+ print (' No flight found' )
204209 break
205210 else :
206211 flight = result .output
207- print (f" Flight found: { flight } " )
212+ print (f' Flight found: { flight } ' )
208213 answer = Prompt .ask (
209- " Do you want to buy this flight, or keep searching? (buy/*search)" ,
210- choices = [" buy" , " search" , "" ],
214+ ' Do you want to buy this flight, or keep searching? (buy/*search)' ,
215+ choices = [' buy' , ' search' , '' ],
211216 show_choices = False ,
212217 )
213- if answer == " buy" :
218+ if answer == ' buy' :
214219 seat = await find_seat (usage )
215220 await buy_tickets (flight , seat )
216221 break
217222 else :
218223 message_history = result .all_messages (
219- output_tool_return_content = " Please suggest another flight"
224+ output_tool_return_content = ' Please suggest another flight'
220225 )
221226
222227
223228async def find_seat (usage : RunUsage ) -> SeatPreference :
224229 message_history : list [ModelMessage ] | None = None
225230 while True :
226- answer = Prompt .ask (" What seat would you like?" )
231+ answer = Prompt .ask (' What seat would you like?' )
227232
228233 result = await seat_preference_agent .run (
229234 answer ,
@@ -234,15 +239,15 @@ async def find_seat(usage: RunUsage) -> SeatPreference:
234239 if isinstance (result .output , SeatPreference ):
235240 return result .output
236241 else :
237- print (" Could not understand seat preference. Please try again." )
242+ print (' Could not understand seat preference. Please try again.' )
238243 message_history = result .all_messages ()
239244
240245
241246async def buy_tickets (flight_details : FlightDetails , seat : SeatPreference ):
242- print (f" Purchasing flight { flight_details = !r} { seat = !r} ..." )
247+ print (f' Purchasing flight { flight_details = !r} { seat = !r} ...' )
243248
244249
245- if __name__ == " __main__" :
250+ if __name__ == ' __main__' :
246251 import asyncio
247252
248253 asyncio .run (main ())
0 commit comments