Skip to content

Commit 2b58f59

Browse files
committed
Update stock analytics agent to stream responses
1 parent 0f1c113 commit 2b58f59

File tree

2 files changed

+32
-104
lines changed

2 files changed

+32
-104
lines changed

examples/pydantic_ai_examples/stock_analysis_agent.py

Lines changed: 31 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
1-
"""Example of using Grok's server-side tools (web_search, code_execution) with a local function.
1+
"""Example of using Grok's server-side web_search tool.
22
33
This agent:
4-
1. Uses web_search to find the best performing NASDAQ stock over the last week
5-
2. Uses code_execution to project the price using linear regression
6-
3. Calls a local function project_price with the results
4+
1. Uses web_search to find the hottest performing stock yesterday
5+
2. Provides buy analysis for the user
76
"""
87

98
import os
10-
from datetime import datetime
119

1210
import logfire
1311
from pydantic import BaseModel, Field
1412

1513
from pydantic_ai import (
1614
Agent,
1715
BuiltinToolCallPart,
18-
CodeExecutionTool,
19-
ModelResponse,
20-
RunContext,
2116
WebSearchTool,
2217
)
2318
from pydantic_ai.models.grok import GrokModel
@@ -35,126 +30,62 @@
3530
model = GrokModel('grok-4-fast', api_key=xai_api_key)
3631

3732

38-
class StockProjection(BaseModel):
39-
"""Projection of stock price at year end."""
33+
class StockAnalysis(BaseModel):
34+
"""Analysis of top performing stock."""
4035

4136
stock_symbol: str = Field(description='Stock ticker symbol')
4237
current_price: float = Field(description='Current stock price')
43-
projected_price: float = Field(description='Projected price at end of year')
44-
analysis: str = Field(description='Brief analysis of the projection')
38+
buy_analysis: str = Field(description='Brief analysis for whether to buy the stock')
4539

4640

47-
# This agent uses server-side tools to research and analyze stocks
48-
stock_analysis_agent = Agent[None, StockProjection](
41+
# This agent uses server-side web search to research stocks
42+
stock_analysis_agent = Agent[None, StockAnalysis](
4943
model=model,
50-
output_type=StockProjection,
51-
builtin_tools=[
52-
WebSearchTool(), # Server-side web search
53-
CodeExecutionTool(), # Server-side code execution
54-
],
44+
output_type=StockAnalysis,
45+
builtin_tools=[WebSearchTool()],
5546
system_prompt=(
5647
'You are a stock analysis assistant. '
57-
'Use web_search to find recent stock performance data on NASDAQ. '
58-
'Use code_execution to perform linear regression for price projection. '
59-
'After analysis, call project_price with your findings.'
48+
'Use web_search to find the hottest performing stock from yesterday on NASDAQ. '
49+
'Provide the current price and a brief buy analysis explaining whether this is a good buy.'
6050
),
6151
)
6252

6353

64-
@stock_analysis_agent.tool
65-
def project_price(ctx: RunContext[None], stock: str, price: float) -> str:
66-
"""Record the projected stock price.
67-
68-
This is a local/client-side function that gets called with the analysis results.
69-
70-
Args:
71-
ctx: The run context (not used in this function)
72-
stock: Stock ticker symbol
73-
price: Projected price at end of year
74-
"""
75-
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
76-
logfire.info(
77-
'Stock projection recorded',
78-
stock=stock,
79-
projected_price=price,
80-
timestamp=timestamp,
81-
)
82-
print('\n📊 PROJECTION RECORDED:')
83-
print(f' Stock: {stock}')
84-
print(f' Projected End-of-Year Price: ${price:.2f}')
85-
print(f' Timestamp: {timestamp}\n')
86-
87-
return f'Projection for {stock} at ${price:.2f} has been recorded successfully.'
88-
89-
9054
async def main():
9155
"""Run the stock analysis agent."""
92-
query = (
93-
'Can you find me the best performing stock on the NASDAQ over the last week, '
94-
'and return the price project for the end of the year using a simple linear regression. '
95-
)
56+
query = 'What was the hottest performing stock on NASDAQ yesterday?'
9657

9758
print('🔍 Starting stock analysis...\n')
9859
print(f'Query: {query}\n')
9960

100-
result = await stock_analysis_agent.run(query)
101-
102-
# Track which builtin tools were used
103-
web_search_count = 0
104-
code_execution_count = 0
105-
106-
for message in result.all_messages():
107-
if isinstance(message, ModelResponse):
61+
async with stock_analysis_agent.run_stream(query) as result:
62+
# Stream responses as they happen
63+
async for message, _is_last in result.stream_responses():
10864
for part in message.parts:
10965
if isinstance(part, BuiltinToolCallPart):
110-
if 'web_search' in part.tool_name or 'browse' in part.tool_name:
111-
web_search_count += 1
112-
logfire.info(
113-
'Server-side web_search tool called',
114-
tool_name=part.tool_name,
115-
tool_call_id=part.tool_call_id,
116-
)
117-
elif 'code_execution' in part.tool_name:
118-
code_execution_count += 1
119-
logfire.info(
120-
'Server-side code_execution tool called',
121-
tool_name=part.tool_name,
122-
tool_call_id=part.tool_call_id,
123-
code=part.args_as_dict().get('code', 'N/A')
124-
if part.args
125-
else 'N/A',
126-
)
127-
128-
print('\n✅ Analysis complete!')
129-
print('\n🔧 Server-Side Tools Used:')
130-
print(f' Web Search calls: {web_search_count}')
131-
print(f' Code Execution calls: {code_execution_count}')
132-
133-
print(f'\nStock: {result.output.stock_symbol}')
134-
print(f'Current Price: ${result.output.current_price:.2f}')
135-
print(f'Projected Year-End Price: ${result.output.projected_price:.2f}')
136-
print(f'\nAnalysis: {result.output.analysis}')
137-
138-
# Get the final response message for metadata
139-
print(result.all_messages())
140-
final_message = result.all_messages()[-1]
141-
if isinstance(final_message, ModelResponse):
142-
print('\n🆔 Response Metadata:')
143-
if final_message.provider_response_id:
144-
print(f' Response ID: {final_message.provider_response_id}')
145-
if final_message.model_name:
146-
print(f' Model: {final_message.model_name}')
147-
if final_message.timestamp:
148-
print(f' Timestamp: {final_message.timestamp}')
66+
print(f'🔧 Server-side tool: {part.tool_name}\n')
67+
68+
# Access output after streaming is complete
69+
output = await result.get_output()
70+
71+
print('\n✅ Analysis complete!\n')
72+
73+
print(f'📊 Top Stock: {output.stock_symbol}')
74+
print(f'💰 Current Price: ${output.current_price:.2f}')
75+
print(f'\n📈 Buy Analysis:\n{output.buy_analysis}')
14976

15077
# Show usage statistics
15178
usage = result.usage()
152-
print('\n📈 Usage Statistics:')
79+
print('\n📊 Usage Statistics:')
15380
print(f' Requests: {usage.requests}')
15481
print(f' Input Tokens: {usage.input_tokens}')
15582
print(f' Output Tokens: {usage.output_tokens}')
15683
print(f' Total Tokens: {usage.total_tokens}')
15784

85+
# Show server-side tools usage if available
86+
if usage.details and 'server_side_tools_used' in usage.details:
87+
print(f' Server-Side Tools: {usage.details["server_side_tools_used"]}')
88+
15889

15990
if __name__ == '__main__':
16091
import asyncio

tests/models/test_grok.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,7 @@ async def test_grok_builtin_web_search_tool(allow_model_requests: None, xai_api_
750750

751751
result = await agent.run('Return just the day of week for the date of Jan 1 in 2026?')
752752
assert result.output
753-
assert result.output.lower() == 'thursday'
753+
assert 'thursday' in result.output.lower()
754754

755755
# Verify that server-side tools were used
756756
usage = result.usage()
@@ -820,9 +820,6 @@ async def test_grok_builtin_multiple_tools(allow_model_requests: None, xai_api_k
820820
messages = result.all_messages()
821821
assert len(messages) >= 2
822822

823-
# The model should use both tools (basic validation that registration works)
824-
# TODO: Add validation for built-in tool usage once response parsing is fully tested
825-
826823

827824
@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)')
828825
async def test_grok_builtin_tools_with_custom_tools(allow_model_requests: None, xai_api_key: str):

0 commit comments

Comments
 (0)