|
1 | | -"""Example of using Grok's server-side tools (web_search, code_execution) with a local function. |
| 1 | +"""Example of using Grok's server-side web_search tool. |
2 | 2 |
|
3 | 3 | This agent: |
4 | | -1. Uses web_search to find the best performing NASDAQ stock over the last week |
5 | | -2. Uses code_execution to project the price using linear regression |
6 | | -3. Calls a local function project_price with the results |
| 4 | +1. Uses web_search to find the hottest performing stock yesterday |
| 5 | +2. Provides buy analysis for the user |
7 | 6 | """ |
8 | 7 |
|
9 | 8 | import os |
10 | | -from datetime import datetime |
11 | 9 |
|
12 | 10 | import logfire |
13 | 11 | from pydantic import BaseModel, Field |
14 | 12 |
|
15 | 13 | from pydantic_ai import ( |
16 | 14 | Agent, |
17 | 15 | BuiltinToolCallPart, |
18 | | - CodeExecutionTool, |
19 | | - ModelResponse, |
20 | | - RunContext, |
21 | 16 | WebSearchTool, |
22 | 17 | ) |
23 | 18 | from pydantic_ai.models.grok import GrokModel |
|
35 | 30 | model = GrokModel('grok-4-fast', api_key=xai_api_key) |
36 | 31 |
|
37 | 32 |
|
38 | | -class StockProjection(BaseModel): |
39 | | - """Projection of stock price at year end.""" |
| 33 | +class StockAnalysis(BaseModel): |
| 34 | + """Analysis of top performing stock.""" |
40 | 35 |
|
41 | 36 | stock_symbol: str = Field(description='Stock ticker symbol') |
42 | 37 | current_price: float = Field(description='Current stock price') |
43 | | - projected_price: float = Field(description='Projected price at end of year') |
44 | | - analysis: str = Field(description='Brief analysis of the projection') |
| 38 | + buy_analysis: str = Field(description='Brief analysis for whether to buy the stock') |
45 | 39 |
|
46 | 40 |
|
47 | | -# This agent uses server-side tools to research and analyze stocks |
48 | | -stock_analysis_agent = Agent[None, StockProjection]( |
| 41 | +# This agent uses server-side web search to research stocks |
| 42 | +stock_analysis_agent = Agent[None, StockAnalysis]( |
49 | 43 | model=model, |
50 | | - output_type=StockProjection, |
51 | | - builtin_tools=[ |
52 | | - WebSearchTool(), # Server-side web search |
53 | | - CodeExecutionTool(), # Server-side code execution |
54 | | - ], |
| 44 | + output_type=StockAnalysis, |
| 45 | + builtin_tools=[WebSearchTool()], |
55 | 46 | system_prompt=( |
56 | 47 | 'You are a stock analysis assistant. ' |
57 | | - 'Use web_search to find recent stock performance data on NASDAQ. ' |
58 | | - 'Use code_execution to perform linear regression for price projection. ' |
59 | | - 'After analysis, call project_price with your findings.' |
| 48 | + 'Use web_search to find the hottest performing stock from yesterday on NASDAQ. ' |
| 49 | + 'Provide the current price and a brief buy analysis explaining whether this is a good buy.' |
60 | 50 | ), |
61 | 51 | ) |
62 | 52 |
|
63 | 53 |
|
64 | | -@stock_analysis_agent.tool |
65 | | -def project_price(ctx: RunContext[None], stock: str, price: float) -> str: |
66 | | - """Record the projected stock price. |
67 | | -
|
68 | | - This is a local/client-side function that gets called with the analysis results. |
69 | | -
|
70 | | - Args: |
71 | | - ctx: The run context (not used in this function) |
72 | | - stock: Stock ticker symbol |
73 | | - price: Projected price at end of year |
74 | | - """ |
75 | | - timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
76 | | - logfire.info( |
77 | | - 'Stock projection recorded', |
78 | | - stock=stock, |
79 | | - projected_price=price, |
80 | | - timestamp=timestamp, |
81 | | - ) |
82 | | - print('\n📊 PROJECTION RECORDED:') |
83 | | - print(f' Stock: {stock}') |
84 | | - print(f' Projected End-of-Year Price: ${price:.2f}') |
85 | | - print(f' Timestamp: {timestamp}\n') |
86 | | - |
87 | | - return f'Projection for {stock} at ${price:.2f} has been recorded successfully.' |
88 | | - |
89 | | - |
90 | 54 | async def main(): |
91 | 55 | """Run the stock analysis agent.""" |
92 | | - query = ( |
93 | | - 'Can you find me the best performing stock on the NASDAQ over the last week, ' |
94 | | - 'and return the price project for the end of the year using a simple linear regression. ' |
95 | | - ) |
| 56 | + query = 'What was the hottest performing stock on NASDAQ yesterday?' |
96 | 57 |
|
97 | 58 | print('🔍 Starting stock analysis...\n') |
98 | 59 | print(f'Query: {query}\n') |
99 | 60 |
|
100 | | - result = await stock_analysis_agent.run(query) |
101 | | - |
102 | | - # Track which builtin tools were used |
103 | | - web_search_count = 0 |
104 | | - code_execution_count = 0 |
105 | | - |
106 | | - for message in result.all_messages(): |
107 | | - if isinstance(message, ModelResponse): |
| 61 | + async with stock_analysis_agent.run_stream(query) as result: |
| 62 | + # Stream responses as they happen |
| 63 | + async for message, _is_last in result.stream_responses(): |
108 | 64 | for part in message.parts: |
109 | 65 | if isinstance(part, BuiltinToolCallPart): |
110 | | - if 'web_search' in part.tool_name or 'browse' in part.tool_name: |
111 | | - web_search_count += 1 |
112 | | - logfire.info( |
113 | | - 'Server-side web_search tool called', |
114 | | - tool_name=part.tool_name, |
115 | | - tool_call_id=part.tool_call_id, |
116 | | - ) |
117 | | - elif 'code_execution' in part.tool_name: |
118 | | - code_execution_count += 1 |
119 | | - logfire.info( |
120 | | - 'Server-side code_execution tool called', |
121 | | - tool_name=part.tool_name, |
122 | | - tool_call_id=part.tool_call_id, |
123 | | - code=part.args_as_dict().get('code', 'N/A') |
124 | | - if part.args |
125 | | - else 'N/A', |
126 | | - ) |
127 | | - |
128 | | - print('\n✅ Analysis complete!') |
129 | | - print('\n🔧 Server-Side Tools Used:') |
130 | | - print(f' Web Search calls: {web_search_count}') |
131 | | - print(f' Code Execution calls: {code_execution_count}') |
132 | | - |
133 | | - print(f'\nStock: {result.output.stock_symbol}') |
134 | | - print(f'Current Price: ${result.output.current_price:.2f}') |
135 | | - print(f'Projected Year-End Price: ${result.output.projected_price:.2f}') |
136 | | - print(f'\nAnalysis: {result.output.analysis}') |
137 | | - |
138 | | - # Get the final response message for metadata |
139 | | - print(result.all_messages()) |
140 | | - final_message = result.all_messages()[-1] |
141 | | - if isinstance(final_message, ModelResponse): |
142 | | - print('\n🆔 Response Metadata:') |
143 | | - if final_message.provider_response_id: |
144 | | - print(f' Response ID: {final_message.provider_response_id}') |
145 | | - if final_message.model_name: |
146 | | - print(f' Model: {final_message.model_name}') |
147 | | - if final_message.timestamp: |
148 | | - print(f' Timestamp: {final_message.timestamp}') |
| 66 | + print(f'🔧 Server-side tool: {part.tool_name}\n') |
| 67 | + |
| 68 | + # Access output after streaming is complete |
| 69 | + output = await result.get_output() |
| 70 | + |
| 71 | + print('\n✅ Analysis complete!\n') |
| 72 | + |
| 73 | + print(f'📊 Top Stock: {output.stock_symbol}') |
| 74 | + print(f'💰 Current Price: ${output.current_price:.2f}') |
| 75 | + print(f'\n📈 Buy Analysis:\n{output.buy_analysis}') |
149 | 76 |
|
150 | 77 | # Show usage statistics |
151 | 78 | usage = result.usage() |
152 | | - print('\n📈 Usage Statistics:') |
| 79 | + print('\n📊 Usage Statistics:') |
153 | 80 | print(f' Requests: {usage.requests}') |
154 | 81 | print(f' Input Tokens: {usage.input_tokens}') |
155 | 82 | print(f' Output Tokens: {usage.output_tokens}') |
156 | 83 | print(f' Total Tokens: {usage.total_tokens}') |
157 | 84 |
|
| 85 | + # Show server-side tools usage if available |
| 86 | + if usage.details and 'server_side_tools_used' in usage.details: |
| 87 | + print(f' Server-Side Tools: {usage.details["server_side_tools_used"]}') |
| 88 | + |
158 | 89 |
|
159 | 90 | if __name__ == '__main__': |
160 | 91 | import asyncio |
|
0 commit comments