diff --git a/py/plugins/xai/src/genkit/plugins/xai/models.py b/py/plugins/xai/src/genkit/plugins/xai/models.py index a8fff11874..e3f1a8b4f2 100644 --- a/py/plugins/xai/src/genkit/plugins/xai/models.py +++ b/py/plugins/xai/src/genkit/plugins/xai/models.py @@ -242,27 +242,30 @@ def _to_xai_messages(self, messages: list[Message]) -> list[chat_pb2.Message]: raise ValueError('xAI models require a URL for media parts.') content.append(chat_pb2.Content(image_url=chat_pb2.ImageURL(url=actual_part.media.url))) elif isinstance(actual_part, ToolRequestPart): + # Serialize tool arguments safely + try: + arguments = json.dumps(actual_part.tool_request.input) + except (TypeError, ValueError): + arguments = str(actual_part.tool_request.input) tool_calls.append( chat_pb2.ToolCall( id=actual_part.tool_request.ref, - type=chat_pb2.ToolCallType.FUNCTION, - function=chat_pb2.Function( + function=chat_pb2.FunctionCall( name=actual_part.tool_request.name, - arguments=actual_part.tool_request.input, + arguments=arguments, ), ) ) elif isinstance(actual_part, ToolResponsePart): - result.append( - chat_pb2.Message( - role=chat_pb2.MessageRole.ROLE_TOOL, - tool_call_id=actual_part.tool_response.ref, - content=[chat_pb2.Content(text=str(actual_part.tool_response.output))], - ) - ) + # xAI doesn't support tool response messages in conversation history continue - pb_message = chat_pb2.Message(role=role, content=content) + # Add empty content for messages that would otherwise be empty + if not content: + content.append(chat_pb2.Content(text='')) + + pb_message = chat_pb2.Message(role=role) + pb_message.content.extend(content) if tool_calls: pb_message.tool_calls.extend(tool_calls) diff --git a/py/samples/xai-hello/src/main.py b/py/samples/xai-hello/src/main.py index d1c7ccdf06..ea7f77a58e 100755 --- a/py/samples/xai-hello/src/main.py +++ b/py/samples/xai-hello/src/main.py @@ -45,18 +45,18 @@ class WeatherInput(BaseModel): class CalculatorInput(BaseModel): - operation: str = Field(description='Math operation: add, subtract, multiply, divide') - a: float = Field(description='First number') - b: float = Field(description='Second number') + num1: float = Field(description='First number') + num2: float = Field(description='Second number') + operation: str = Field(description='Operation: +, -, *, /') @ai.tool() def get_weather(input: WeatherInput) -> dict: weather_data = { - 'New York': {'temp': 15, 'condition': 'cloudy', 'humidity': 65}, - 'London': {'temp': 12, 'condition': 'rainy', 'humidity': 78}, - 'Tokyo': {'temp': 20, 'condition': 'sunny', 'humidity': 55}, - 'Paris': {'temp': 14, 'condition': 'partly cloudy', 'humidity': 60}, + 'New York': {'temp': 15, 'condition': 'cloudy'}, + 'London': {'temp': 12, 'condition': 'rainy'}, + 'Tokyo': {'temp': 20, 'condition': 'sunny'}, + 'Paris': {'temp': 14, 'condition': 'windy'}, } location = input.location.title() @@ -74,22 +74,25 @@ def get_weather(input: WeatherInput) -> dict: @ai.tool() def calculate(input: CalculatorInput) -> dict: - operations = { - 'add': lambda a, b: a + b, - 'subtract': lambda a, b: a - b, - 'multiply': lambda a, b: a * b, - 'divide': lambda a, b: a / b if b != 0 else None, - } - - operation = input.operation.lower() - if operation not in operations: - return {'error': f'Unknown operation: {operation}'} + a, op, b = input.num1, input.operation, input.num2 + + if op == '+': + result = a + b + elif op == '-': + result = a - b + elif op == '*': + result = a * b + elif op == '/': + if b == 0: + return {'error': 'Division by zero'} + result = a / b + else: + return {'error': f'Unknown operator: {op}'} - result = operations[operation](input.a, input.b) return { - 'operation': operation, - 'a': input.a, - 'b': input.b, + 'num1': a, + 'operation': op, + 'num2': b, 'result': result, } @@ -120,23 +123,19 @@ async def say_hi_with_config(name: str) -> str: @ai.flow() async def weather_flow(location: str) -> str: - weather_data = get_weather(WeatherInput(location=location)) - return ( - f'Weather in {location}: {weather_data.get("temp")}°{weather_data.get("unit")}, {weather_data.get("condition")}' + response = await ai.generate( + prompt=f'What is the weather in {location}? Be concise and only provide the weather information.', + tools=['get_weather'], ) + return response.text @ai.flow() -async def calculator_flow(expression: str) -> str: - parts = expression.split('_') - if len(parts) < 3: - return 'Invalid expression format. Use: operation_a_b (e.g., add_5_3)' - - operation, a, b = parts[0], float(parts[1]), float(parts[2]) - result = calculate(CalculatorInput(operation=operation, a=a, b=b)) +async def calculator_flow(input: CalculatorInput) -> str: + result = calculate(input) if 'error' in result: return f'Error: {result["error"]}' - return f'{operation.title()}({a}, {b}) = {result.get("result")}' + return f'{result["num1"]} {result["operation"]} {result["num2"]} = {result["result"]}' async def main(): @@ -152,7 +151,7 @@ async def main(): result = await weather_flow('New York') logger.info('weather_flow', result=result) - result = await calculator_flow('add_5_3') + result = await calculator_flow(CalculatorInput(num1=5.0, operation='+', num2=3.0)) logger.info('calculator_flow', result=result)