Skip to content

Commit 43e3ac8

Browse files
committed
fix(py): Fixed xAI samples
1 parent 2b015cd commit 43e3ac8

File tree

2 files changed

+42
-43
lines changed

2 files changed

+42
-43
lines changed

py/plugins/xai/src/genkit/plugins/xai/models.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -242,27 +242,30 @@ def _to_xai_messages(self, messages: list[Message]) -> list[chat_pb2.Message]:
242242
raise ValueError('xAI models require a URL for media parts.')
243243
content.append(chat_pb2.Content(image_url=chat_pb2.ImageURL(url=actual_part.media.url)))
244244
elif isinstance(actual_part, ToolRequestPart):
245+
# Serialize tool arguments safely
246+
try:
247+
arguments = json.dumps(actual_part.tool_request.input)
248+
except (TypeError, ValueError):
249+
arguments = str(actual_part.tool_request.input)
245250
tool_calls.append(
246251
chat_pb2.ToolCall(
247252
id=actual_part.tool_request.ref,
248-
type=chat_pb2.ToolCallType.FUNCTION,
249-
function=chat_pb2.Function(
253+
function=chat_pb2.FunctionCall(
250254
name=actual_part.tool_request.name,
251-
arguments=actual_part.tool_request.input,
255+
arguments=arguments,
252256
),
253257
)
254258
)
255259
elif isinstance(actual_part, ToolResponsePart):
256-
result.append(
257-
chat_pb2.Message(
258-
role=chat_pb2.MessageRole.ROLE_TOOL,
259-
tool_call_id=actual_part.tool_response.ref,
260-
content=[chat_pb2.Content(text=str(actual_part.tool_response.output))],
261-
)
262-
)
260+
# xAI doesn't support tool response messages in conversation history
263261
continue
264262

265-
pb_message = chat_pb2.Message(role=role, content=content)
263+
# Add empty content for messages that would otherwise be empty
264+
if not content:
265+
content.append(chat_pb2.Content(text=''))
266+
267+
pb_message = chat_pb2.Message(role=role)
268+
pb_message.content.extend(content)
266269
if tool_calls:
267270
pb_message.tool_calls.extend(tool_calls)
268271

py/samples/xai-hello/src/main.py

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,18 @@ class WeatherInput(BaseModel):
4545

4646

4747
class CalculatorInput(BaseModel):
48-
operation: str = Field(description='Math operation: add, subtract, multiply, divide')
49-
a: float = Field(description='First number')
50-
b: float = Field(description='Second number')
48+
num1: float = Field(description='First number')
49+
num2: float = Field(description='Second number')
50+
operation: str = Field(description='Operation: +, -, *, /')
5151

5252

5353
@ai.tool()
5454
def get_weather(input: WeatherInput) -> dict:
5555
weather_data = {
56-
'New York': {'temp': 15, 'condition': 'cloudy', 'humidity': 65},
57-
'London': {'temp': 12, 'condition': 'rainy', 'humidity': 78},
58-
'Tokyo': {'temp': 20, 'condition': 'sunny', 'humidity': 55},
59-
'Paris': {'temp': 14, 'condition': 'partly cloudy', 'humidity': 60},
56+
'New York': {'temp': 15, 'condition': 'cloudy'},
57+
'London': {'temp': 12, 'condition': 'rainy'},
58+
'Tokyo': {'temp': 20, 'condition': 'sunny'},
59+
'Paris': {'temp': 14, 'condition': 'windy'},
6060
}
6161

6262
location = input.location.title()
@@ -74,22 +74,22 @@ def get_weather(input: WeatherInput) -> dict:
7474

7575
@ai.tool()
7676
def calculate(input: CalculatorInput) -> dict:
77-
operations = {
78-
'add': lambda a, b: a + b,
79-
'subtract': lambda a, b: a - b,
80-
'multiply': lambda a, b: a * b,
81-
'divide': lambda a, b: a / b if b != 0 else None,
82-
}
77+
a, op, b = input.num1, input.operation, input.num2
78+
79+
if op not in ['+', '-', '*', '/']:
80+
return {'error': f'Unknown operator: {op}'}
8381

84-
operation = input.operation.lower()
85-
if operation not in operations:
86-
return {'error': f'Unknown operation: {operation}'}
82+
if op == '/':
83+
if b == 0:
84+
return {'error': 'Division by zero'}
85+
result = a / b
86+
else:
87+
result = {'+': a + b, '-': a - b, '*': a * b}[op]
8788

88-
result = operations[operation](input.a, input.b)
8989
return {
90-
'operation': operation,
91-
'a': input.a,
92-
'b': input.b,
90+
'num1': a,
91+
'operation': op,
92+
'num2': b,
9393
'result': result,
9494
}
9595

@@ -120,23 +120,19 @@ async def say_hi_with_config(name: str) -> str:
120120

121121
@ai.flow()
122122
async def weather_flow(location: str) -> str:
123-
weather_data = get_weather(WeatherInput(location=location))
124-
return (
125-
f'Weather in {location}: {weather_data.get("temp")}°{weather_data.get("unit")}, {weather_data.get("condition")}'
123+
response = await ai.generate(
124+
prompt=f'What is the weather in {location}? Be concise and only provide the weather information.',
125+
tools=['get_weather'],
126126
)
127+
return response.text
127128

128129

129130
@ai.flow()
130-
async def calculator_flow(expression: str) -> str:
131-
parts = expression.split('_')
132-
if len(parts) < 3:
133-
return 'Invalid expression format. Use: operation_a_b (e.g., add_5_3)'
134-
135-
operation, a, b = parts[0], float(parts[1]), float(parts[2])
136-
result = calculate(CalculatorInput(operation=operation, a=a, b=b))
131+
async def calculator_flow(input: CalculatorInput) -> str:
132+
result = calculate(input)
137133
if 'error' in result:
138134
return f'Error: {result["error"]}'
139-
return f'{operation.title()}({a}, {b}) = {result.get("result")}'
135+
return f'{result["num1"]} {result["operation"]} {result["num2"]} = {result["result"]}'
140136

141137

142138
async def main():
@@ -152,7 +148,7 @@ async def main():
152148
result = await weather_flow('New York')
153149
logger.info('weather_flow', result=result)
154150

155-
result = await calculator_flow('add_5_3')
151+
result = await calculator_flow(CalculatorInput(num1=5.0, operation='+', num2=3.0))
156152
logger.info('calculator_flow', result=result)
157153

158154

0 commit comments

Comments
 (0)