Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions py/plugins/xai/src/genkit/plugins/xai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,27 +242,30 @@ def _to_xai_messages(self, messages: list[Message]) -> list[chat_pb2.Message]:
raise ValueError('xAI models require a URL for media parts.')
content.append(chat_pb2.Content(image_url=chat_pb2.ImageURL(url=actual_part.media.url)))
elif isinstance(actual_part, ToolRequestPart):
# Serialize tool arguments safely
try:
arguments = json.dumps(actual_part.tool_request.input)
except (TypeError, ValueError):
arguments = str(actual_part.tool_request.input)
tool_calls.append(
chat_pb2.ToolCall(
id=actual_part.tool_request.ref,
type=chat_pb2.ToolCallType.FUNCTION,
function=chat_pb2.Function(
function=chat_pb2.FunctionCall(
name=actual_part.tool_request.name,
arguments=actual_part.tool_request.input,
arguments=arguments,
),
)
)
elif isinstance(actual_part, ToolResponsePart):
result.append(
chat_pb2.Message(
role=chat_pb2.MessageRole.ROLE_TOOL,
tool_call_id=actual_part.tool_response.ref,
content=[chat_pb2.Content(text=str(actual_part.tool_response.output))],
)
)
# xAI doesn't support tool response messages in conversation history
continue

pb_message = chat_pb2.Message(role=role, content=content)
# Add empty content for messages that would otherwise be empty
if not content:
content.append(chat_pb2.Content(text=''))

pb_message = chat_pb2.Message(role=role)
pb_message.content.extend(content)
if tool_calls:
pb_message.tool_calls.extend(tool_calls)

Expand Down
65 changes: 32 additions & 33 deletions py/samples/xai-hello/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,18 @@ class WeatherInput(BaseModel):


class CalculatorInput(BaseModel):
operation: str = Field(description='Math operation: add, subtract, multiply, divide')
a: float = Field(description='First number')
b: float = Field(description='Second number')
num1: float = Field(description='First number')
num2: float = Field(description='Second number')
operation: str = Field(description='Operation: +, -, *, /')


@ai.tool()
def get_weather(input: WeatherInput) -> dict:
weather_data = {
'New York': {'temp': 15, 'condition': 'cloudy', 'humidity': 65},
'London': {'temp': 12, 'condition': 'rainy', 'humidity': 78},
'Tokyo': {'temp': 20, 'condition': 'sunny', 'humidity': 55},
'Paris': {'temp': 14, 'condition': 'partly cloudy', 'humidity': 60},
'New York': {'temp': 15, 'condition': 'cloudy'},
'London': {'temp': 12, 'condition': 'rainy'},
'Tokyo': {'temp': 20, 'condition': 'sunny'},
'Paris': {'temp': 14, 'condition': 'windy'},
}

location = input.location.title()
Expand All @@ -74,22 +74,25 @@ def get_weather(input: WeatherInput) -> dict:

@ai.tool()
def calculate(input: CalculatorInput) -> dict:
operations = {
'add': lambda a, b: a + b,
'subtract': lambda a, b: a - b,
'multiply': lambda a, b: a * b,
'divide': lambda a, b: a / b if b != 0 else None,
}

operation = input.operation.lower()
if operation not in operations:
return {'error': f'Unknown operation: {operation}'}
a, op, b = input.num1, input.operation, input.num2

if op == '+':
result = a + b
elif op == '-':
result = a - b
elif op == '*':
result = a * b
elif op == '/':
if b == 0:
return {'error': 'Division by zero'}
result = a / b
else:
return {'error': f'Unknown operator: {op}'}

result = operations[operation](input.a, input.b)
return {
'operation': operation,
'a': input.a,
'b': input.b,
'num1': a,
'operation': op,
'num2': b,
'result': result,
}

Expand Down Expand Up @@ -120,23 +123,19 @@ async def say_hi_with_config(name: str) -> str:

@ai.flow()
async def weather_flow(location: str) -> str:
weather_data = get_weather(WeatherInput(location=location))
return (
f'Weather in {location}: {weather_data.get("temp")}°{weather_data.get("unit")}, {weather_data.get("condition")}'
response = await ai.generate(
prompt=f'What is the weather in {location}? Be concise and only provide the weather information.',
tools=['get_weather'],
)
return response.text


@ai.flow()
async def calculator_flow(expression: str) -> str:
parts = expression.split('_')
if len(parts) < 3:
return 'Invalid expression format. Use: operation_a_b (e.g., add_5_3)'

operation, a, b = parts[0], float(parts[1]), float(parts[2])
result = calculate(CalculatorInput(operation=operation, a=a, b=b))
async def calculator_flow(input: CalculatorInput) -> str:
result = calculate(input)
if 'error' in result:
return f'Error: {result["error"]}'
return f'{operation.title()}({a}, {b}) = {result.get("result")}'
return f'{result["num1"]} {result["operation"]} {result["num2"]} = {result["result"]}'


async def main():
Expand All @@ -152,7 +151,7 @@ async def main():
result = await weather_flow('New York')
logger.info('weather_flow', result=result)

result = await calculator_flow('add_5_3')
result = await calculator_flow(CalculatorInput(num1=5.0, operation='+', num2=3.0))
logger.info('calculator_flow', result=result)


Expand Down
Loading