diff --git a/examples/callModel-typed-tool-calling.example.ts b/examples/callModel-typed-tool-calling.example.ts new file mode 100644 index 00000000..2f383ac2 --- /dev/null +++ b/examples/callModel-typed-tool-calling.example.ts @@ -0,0 +1,169 @@ +/* + * Example: Typed Tool Calling with callModel + * + * This example demonstrates how to use the tool() function for + * fully-typed tool definitions where execute params, return types, and event + * types are automatically inferred from Zod schemas. + * + * Tool types are auto-detected based on configuration: + * - Generator tool: When `eventSchema` is provided + * - Regular tool: When `execute` is a function (no `eventSchema`) + * - Manual tool: When `execute: false` is set + * + * To run this example from the examples directory: + * npm run build && npx tsx callModel-typed-tool-calling.example.ts + */ + +import dotenv from "dotenv"; +dotenv.config(); + +import { OpenRouter, tool } from "../src/index.js"; +import z from "zod"; + +const openRouter = new OpenRouter({ + apiKey: process.env["OPENROUTER_API_KEY"] ?? "", +}); + +// Create a typed regular tool using tool() +// The execute function params are automatically typed as z.infer +// The return type is enforced based on outputSchema +const weatherTool = tool({ + name: "get_weather", + description: "Get the current weather for a location", + inputSchema: z.object({ + location: z.string().describe("The city and country, e.g. San Francisco, CA"), + }), + outputSchema: z.object({ + temperature: z.number(), + description: z.string(), + }), + // params is automatically typed as { location: string } + execute: async (params) => { + console.log(`Getting weather for: ${params.location}`); + // Return type is enforced as { temperature: number; description: string } + return { + temperature: 20, + description: "Sunny", + }; + }, +}); + +// Create a generator tool with typed progress events by providing eventSchema +// The eventSchema triggers generator mode - execute becomes an async generator +const searchTool = tool({ + name: "search_database", + description: "Search database with progress updates", + inputSchema: z.object({ + query: z.string().describe("The search query"), + }), + eventSchema: z.object({ + progress: z.number(), + message: z.string(), + }), + outputSchema: z.object({ + results: z.array(z.string()), + totalFound: z.number(), + }), + // execute is a generator that yields typed progress events + execute: async function* (params) { + console.log(`Searching for: ${params.query}`); + // Each yield is typed as { progress: number; message: string } + yield { progress: 25, message: "Searching..." }; + yield { progress: 50, message: "Processing results..." }; + yield { progress: 75, message: "Almost done..." }; + // Final result is typed as { results: string[]; totalFound: number } + yield { progress: 100, message: "Complete!" }; + }, +}); + +async function main() { + console.log("=== Typed Tool Calling Example ===\n"); + + // Use 'as const' to enable full type inference for tool calls + const result = openRouter.callModel({ + instructions: "You are a helpful assistant. Your name is Mark", + model: "openai/gpt-4o-mini", + input: "Hello! What is the weather in San Francisco?", + tools: [weatherTool] as const, + }); + + // Get text response (tools are auto-executed) + const text = await result.getText(); + console.log("Response:", text); + + console.log("\n=== Getting Tool Calls ===\n"); + + // Create a fresh request for demonstrating getToolCalls + const result2 = openRouter.callModel({ + model: "openai/gpt-4o-mini", + input: "What's the weather like in Paris?", + tools: [weatherTool] as const, + maxToolRounds: 0, // Don't auto-execute, just get the tool calls + }); + + // Tool calls are now typed based on the tool definitions! + const toolCalls = await result2.getToolCalls(); + + for (const toolCall of toolCalls) { + console.log(`Tool: ${toolCall.name}`); + // toolCall.arguments is typed as { location: string } + console.log(`Arguments:`, toolCall.arguments); + } + + console.log("\n=== Streaming Tool Calls ===\n"); + + // Create another request for demonstrating streaming + const result3 = openRouter.callModel({ + model: "openai/gpt-4o-mini", + input: "What's the weather in Tokyo?", + tools: [weatherTool] as const, + maxToolRounds: 0, + }); + + // Stream tool calls with typed arguments + for await (const toolCall of result3.getToolCallsStream()) { + console.log(`Streamed tool: ${toolCall.name}`); + // toolCall.arguments is typed based on tool definitions + console.log(`Streamed arguments:`, toolCall.arguments); + } + + console.log("\n=== Generator Tool with Typed Events ===\n"); + + // Use generator tool with typed progress events + const result4 = openRouter.callModel({ + model: "openai/gpt-4o-mini", + input: "Search for documents about TypeScript", + tools: [searchTool] as const, + }); + + // Stream events from getToolStream - events are fully typed! + for await (const event of result4.getToolStream()) { + if (event.type === "preliminary_result") { + // event.result is typed as { progress: number; message: string } + console.log(`Progress: ${event.result.progress}% - ${event.result.message}`); + } else if (event.type === "delta") { + // Tool argument deltas + process.stdout.write(event.content); + } + } + + console.log("\n=== Mixed Tools with Typed Events ===\n"); + + // Use both regular and generator tools together + const result5 = openRouter.callModel({ + model: "openai/gpt-4o-mini", + input: "First search for weather data, then get the weather in Seattle", + tools: [weatherTool, searchTool] as const, + }); + + // Events are a union of all generator tool event types + for await (const event of result5.getToolStream()) { + if (event.type === "preliminary_result") { + // event.result is typed as { progress: number; message: string } + // (only searchTool has eventSchema, so that's the event type) + console.log(`Event:`, event.result); + } + } +} + +main().catch(console.error); diff --git a/src/funcs/call-model.ts b/src/funcs/call-model.ts index a90ed0da..38893b25 100644 --- a/src/funcs/call-model.ts +++ b/src/funcs/call-model.ts @@ -1,120 +1,18 @@ -import type { OpenRouterCore } from "../core.js"; -import type { RequestOptions } from "../lib/sdks.js"; -import type { Tool, MaxToolRounds } from "../lib/tool-types.js"; -import type * as models from "../models/index.js"; +import type { OpenRouterCore } from '../core.js'; +import type { RequestOptions } from '../lib/sdks.js'; +import type { MaxToolRounds, Tool } from '../lib/tool-types.js'; +import type * as models from '../models/index.js'; -import { ModelResult } from "../lib/model-result.js"; -import { convertToolsToAPIFormat } from "../lib/tool-executor.js"; +import { ModelResult } from '../lib/model-result.js'; +import { convertToolsToAPIFormat } from '../lib/tool-executor.js'; /** - * Checks if a message looks like a Claude-style message + * Input type for callModel function */ -function isClaudeStyleMessage(msg: any): msg is models.ClaudeMessageParam { - if (!msg || typeof msg !== 'object') return false; - - // Check if it has a role field that's user or assistant - const role = msg.role; - if (role !== 'user' && role !== 'assistant') return false; - - // Check if content is an array with Claude-style content blocks - if (Array.isArray(msg.content)) { - return msg.content.some((block: any) => - block && - typeof block === 'object' && - block.type && - // Claude content block types (not OpenRouter types) - (block.type === 'text' || block.type === 'image' || block.type === 'tool_use' || block.type === 'tool_result') - ); - } - - return false; -} - -/** - * Converts Claude-style content blocks to OpenRouter format - */ -function convertClaudeContentBlock( - block: models.ClaudeContentBlockParam -): models.ResponseInputText | models.ResponseInputImage | null { - if (!block || typeof block !== 'object' || !('type' in block)) { - return null; - } - - switch (block.type) { - case 'text': { - const textBlock = block as models.ClaudeTextBlockParam; - return { - type: 'input_text', - text: textBlock.text, - }; - } - case 'image': { - const imageBlock = block as models.ClaudeImageBlockParam; - if (imageBlock.source.type === 'url') { - return { - type: 'input_image', - detail: 'auto', - imageUrl: imageBlock.source.url, - }; - } else if (imageBlock.source.type === 'base64') { - const dataUri = `data:${imageBlock.source.media_type};base64,${imageBlock.source.data}`; - return { - type: 'input_image', - detail: 'auto', - imageUrl: dataUri, - }; - } - return null; - } - case 'tool_use': - case 'tool_result': - // tool_use and tool_result are not handled here as they map to different input types - return null; - default: - return null; - } -} - -/** - * Converts a Claude-style message to OpenRouter EasyInputMessage format - */ -function convertClaudeMessage(msg: models.ClaudeMessageParam): models.OpenResponsesEasyInputMessage { - const { role, content } = msg; - - if (typeof content === 'string') { - return { - role: role === 'user' ? 'user' : 'assistant', - content, - }; - } - - // Convert array of content blocks - const convertedBlocks: (models.ResponseInputText | models.ResponseInputImage)[] = []; - for (const block of content) { - const converted = convertClaudeContentBlock(block); - if (converted) { - convertedBlocks.push(converted); - } - } - - // If all blocks were text, concatenate them into a string - const allText = convertedBlocks.every(b => b.type === 'input_text'); - if (allText) { - const text = convertedBlocks - .map(b => (b as models.ResponseInputText).text) - .join(''); - return { - role: role === 'user' ? 'user' : 'assistant', - content: text, - }; - } - - // Otherwise, return as array - return { - role: role === 'user' ? 'user' : 'assistant', - content: convertedBlocks, - }; -} +export type CallModelInput = Omit & { + tools?: Tool[]; + maxToolRounds?: MaxToolRounds; +}; /** * Get a response with multiple consumption patterns @@ -141,36 +39,20 @@ function convertClaudeMessage(msg: models.ClaudeMessageParam): models.OpenRespon */ export function callModel( client: OpenRouterCore, - request: Omit & { - tools?: Tool[]; - maxToolRounds?: MaxToolRounds; - }, - options?: RequestOptions + request: CallModelInput, + options?: RequestOptions, ): ModelResult { const { tools, maxToolRounds, ...apiRequest } = request; - // Auto-convert Claude-style messages if detected - let processedInput = apiRequest.input; - if (Array.isArray(apiRequest.input)) { - const hasClaudeMessages = apiRequest.input.some(isClaudeStyleMessage); - if (hasClaudeMessages) { - processedInput = apiRequest.input.map((msg: any) => { - if (isClaudeStyleMessage(msg)) { - return convertClaudeMessage(msg); - } - return msg; - }); - } - } - // Convert tools to API format and extract enhanced tools if present const apiTools = tools ? convertToolsToAPIFormat(tools) : undefined; - // Build the request with converted tools and input + // Build the request with converted tools const finalRequest: models.OpenResponsesRequest = { ...apiRequest, - ...(processedInput !== undefined && { input: processedInput }), - ...(apiTools !== undefined && { tools: apiTools }), + ...(apiTools !== undefined && { + tools: apiTools, + }), }; return new ModelResult({ @@ -178,6 +60,8 @@ export function callModel( request: finalRequest, options: options ?? {}, tools: tools ?? [], - ...(maxToolRounds !== undefined && { maxToolRounds }), + ...(maxToolRounds !== undefined && { + maxToolRounds, + }), }); } diff --git a/src/index.ts b/src/index.ts index c9a3a755..e07ef423 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,3 +12,61 @@ export * from "./sdk/sdk.js"; export { fromClaudeMessages, toClaudeMessage } from "./lib/anthropic-compat.js"; export { fromChatMessages, toChatMessage } from "./lib/chat-compat.js"; export { extractUnsupportedContent, hasUnsupportedContent, getUnsupportedContentSummary } from "./lib/stream-transformers.js"; + +// Claude message types +export type { + ClaudeMessage, + ClaudeMessageParam, + ClaudeContentBlock, + ClaudeContentBlockParam, + ClaudeTextBlock, + ClaudeThinkingBlock, + ClaudeRedactedThinkingBlock, + ClaudeToolUseBlock, + ClaudeServerToolUseBlock, + ClaudeTextBlockParam, + ClaudeImageBlockParam, + ClaudeToolUseBlockParam, + ClaudeToolResultBlockParam, + ClaudeStopReason, + ClaudeUsage, + ClaudeCacheControl, + ClaudeTextCitation, + ClaudeCitationCharLocation, + ClaudeCitationPageLocation, + ClaudeCitationContentBlockLocation, + ClaudeCitationWebSearchResultLocation, + ClaudeCitationSearchResultLocation, + ClaudeBase64ImageSource, + ClaudeURLImageSource, +} from "./models/claude-message.js"; + +// Tool creation helpers +export { tool } from "./lib/tool.js"; + +// Tool types +export type { + Tool, + ToolWithExecute, + ToolWithGenerator, + ManualTool, + TurnContext, + InferToolInput, + InferToolOutput, + InferToolEvent, + InferToolEventsUnion, + TypedToolCall, + TypedToolCallUnion, + ToolStreamEvent, + ChatStreamEvent, + EnhancedResponseStreamEvent, + ToolPreliminaryResultEvent, +} from "./lib/tool-types.js"; + +export { + ToolType, + hasExecuteFunction, + isGeneratorTool, + isRegularExecuteTool, + isToolPreliminaryResultEvent, +} from "./lib/tool-types.js"; diff --git a/src/lib/anthropic-compat.ts b/src/lib/anthropic-compat.ts index 52294828..1fe68c25 100644 --- a/src/lib/anthropic-compat.ts +++ b/src/lib/anthropic-compat.ts @@ -1,22 +1,18 @@ -import type * as models from "../models/index.js"; +import type * as models from '../models/index.js'; + import { - OpenResponsesEasyInputMessageRoleUser, OpenResponsesEasyInputMessageRoleAssistant, -} from "../models/openresponseseasyinputmessage.js"; -import { - OpenResponsesInputMessageItemRoleUser, - OpenResponsesInputMessageItemRoleSystem, -} from "../models/openresponsesinputmessageitem.js"; -import { OpenResponsesFunctionCallOutputType } from "../models/openresponsesfunctioncalloutput.js"; -import { convertToClaudeMessage } from "./stream-transformers.js"; + OpenResponsesEasyInputMessageRoleUser, +} from '../models/openresponseseasyinputmessage.js'; +import { OpenResponsesFunctionCallOutputType } from '../models/openresponsesfunctioncalloutput.js'; +import { OpenResponsesInputMessageItemRoleUser, OpenResponsesInputMessageItemRoleDeveloper } from '../models/openresponsesinputmessageitem.js'; +import { convertToClaudeMessage } from './stream-transformers.js'; /** * Maps Claude role strings to OpenResponses role types */ -function mapClaudeRole( - role: "user" | "assistant" -): models.OpenResponsesEasyInputMessageRoleUnion { - if (role === "user") { +function mapClaudeRole(role: 'user' | 'assistant'): models.OpenResponsesEasyInputMessageRoleUnion { + if (role === 'user') { return OpenResponsesEasyInputMessageRoleUser.User; } return OpenResponsesEasyInputMessageRoleAssistant.Assistant; @@ -26,8 +22,8 @@ function mapClaudeRole( * Creates a properly typed OpenResponsesEasyInputMessage */ function createEasyInputMessage( - role: "user" | "assistant", - content: string + role: 'user' | 'assistant', + content: string, ): models.OpenResponsesEasyInputMessage { return { role: mapClaudeRole(role), @@ -40,7 +36,7 @@ function createEasyInputMessage( */ function createFunctionCallOutput( callId: string, - output: string + output: string, ): models.OpenResponsesFunctionCallOutput { return { type: OpenResponsesFunctionCallOutputType.FunctionCallOutput, @@ -55,6 +51,10 @@ function createFunctionCallOutput( * This function transforms ClaudeMessageParam[] (Anthropic SDK format) to * OpenResponsesInput format that can be passed directly to callModel(). * + * Note: Some Claude features are lost in conversion as OpenRouter doesn't support them: + * - cache_control on content blocks + * - is_error flag on tool_result blocks + * * @example * ```typescript * import { fromClaudeMessages } from '@openrouter/sdk'; @@ -71,7 +71,7 @@ function createFunctionCallOutput( * ``` */ export function fromClaudeMessages( - messages: models.ClaudeMessageParam[] + messages: models.ClaudeMessageParam[], ): models.OpenResponsesInput { const result: ( | models.OpenResponsesEasyInputMessage @@ -84,142 +84,150 @@ export function fromClaudeMessages( for (const msg of messages) { const { role, content } = msg; - if (typeof content === "string") { + if (typeof content === 'string') { result.push(createEasyInputMessage(role, content)); continue; } - const contentItems: ( - | models.ResponseInputText - | models.ResponseInputImage - | models.ResponseInputFile - | models.ResponseInputAudio - )[] = []; - let hasStructuredContent = false; + // Separate content blocks into categories for clearer processing + const textBlocks: models.ClaudeTextBlockParam[] = []; + const imageBlocks: models.ClaudeImageBlockParam[] = []; + const toolUseBlocks: models.ClaudeToolUseBlockParam[] = []; + const toolResultBlocks: models.ClaudeToolResultBlockParam[] = []; for (const block of content) { switch (block.type) { - case 'text': { - const textBlock = block as models.ClaudeTextBlockParam; - contentItems.push({ - type: 'input_text', - text: textBlock.text, - }); - // Note: cache_control is lost in conversion (OpenRouter doesn't support it) + case 'text': + textBlocks.push(block as models.ClaudeTextBlockParam); break; - } - - case 'image': { - const imageBlock = block as models.ClaudeImageBlockParam; - hasStructuredContent = true; - - // Convert Claude image source to OpenRouter format - if (imageBlock.source.type === 'url') { - contentItems.push({ - type: 'input_image', - detail: 'auto', - imageUrl: imageBlock.source.url, - }); - } else if (imageBlock.source.type === 'base64') { - // Base64 images: OpenRouter expects a URL, so we use data URI - const dataUri = `data:${imageBlock.source.media_type};base64,${imageBlock.source.data}`; - contentItems.push({ - type: 'input_image', - detail: 'auto', - imageUrl: dataUri, - }); - } + case 'image': + imageBlocks.push(block as models.ClaudeImageBlockParam); + break; + case 'tool_use': + toolUseBlocks.push(block as models.ClaudeToolUseBlockParam); + break; + case 'tool_result': + toolResultBlocks.push(block as models.ClaudeToolResultBlockParam); break; + default: { + const exhaustiveCheck: never = block; + throw new Error(`Unhandled content block type: ${exhaustiveCheck}`); } + } + } - case 'tool_use': { - const toolUseBlock = block as models.ClaudeToolUseBlockParam; + // Process tool use blocks first (they go directly to result) + for (const toolUseBlock of toolUseBlocks) { + result.push({ + type: 'function_call', + callId: toolUseBlock.id, + name: toolUseBlock.name, + arguments: JSON.stringify(toolUseBlock.input), + id: toolUseBlock.id, + status: 'completed', + }); + } - // Map to OpenResponsesFunctionToolCall - result.push({ - type: 'function_call', - callId: toolUseBlock.id, - name: toolUseBlock.name, - arguments: JSON.stringify(toolUseBlock.input), - id: toolUseBlock.id, - status: 'completed', // Tool use in conversation history is already completed - }); - break; + // Process tool result blocks + for (const toolResultBlock of toolResultBlocks) { + let toolOutput = ''; + + if (typeof toolResultBlock.content === 'string') { + toolOutput = toolResultBlock.content; + } else { + // Extract text and handle images separately + const textParts: string[] = []; + const imageParts: models.ClaudeImageBlockParam[] = []; + + for (const part of toolResultBlock.content) { + if (part.type === 'text') { + textParts.push(part.text); + } else if (part.type === 'image') { + imageParts.push(part); + } } - case 'tool_result': { - const toolResultBlock = block as models.ClaudeToolResultBlockParam; + toolOutput = textParts.join(''); + + // Map images to image_generation_call items + imageParts.forEach((imagePart, i) => { + let imageUrl: string; - let toolOutput = ''; - if (typeof toolResultBlock.content === 'string') { - toolOutput = toolResultBlock.content; + if (imagePart.source.type === 'url') { + imageUrl = imagePart.source.url; + } else if (imagePart.source.type === 'base64') { + imageUrl = `data:${imagePart.source.media_type};base64,${imagePart.source.data}`; } else { - // Extract text and handle images separately - const textParts: string[] = []; - const imageParts: models.ClaudeImageBlockParam[] = []; - - for (const part of toolResultBlock.content) { - if (part.type === 'text') { - textParts.push(part.text); - } else if (part.type === 'image') { - imageParts.push(part); - } - } - - toolOutput = textParts.join(''); - - // Map images to image_generation_call items - for (const imagePart of imageParts) { - const imageUrl = imagePart.source.type === 'url' - ? imagePart.source.url - : `data:${imagePart.source.media_type};base64,${imagePart.source.data}`; - - result.push({ - type: 'image_generation_call', - id: `${toolResultBlock.tool_use_id}-image-${imageParts.indexOf(imagePart)}`, - result: imageUrl, - status: 'completed', - }); - } + const exhaustiveCheck: never = imagePart.source; + throw new Error(`Unhandled image source type: ${exhaustiveCheck}`); } - // Add the function call output for the text portion - if (toolOutput || typeof toolResultBlock.content === 'string') { - result.push(createFunctionCallOutput(toolResultBlock.tool_use_id, toolOutput)); - } - break; - } + result.push({ + type: 'image_generation_call', + id: `${toolResultBlock.tool_use_id}-image-${i}`, + result: imageUrl, + status: 'completed', + }); + }); + } - default: { - const _exhaustiveCheck: never = block; - throw new Error(`Unhandled content block type: ${(_exhaustiveCheck as { type: string }).type}`); - } + // Add the function call output for the text portion (if any) + if (toolOutput.length > 0) { + result.push(createFunctionCallOutput(toolResultBlock.tool_use_id, toolOutput)); } } - // Use structured format if we have images, otherwise use simple format - if (contentItems.length > 0) { - if (hasStructuredContent) { - // Use OpenResponsesInputMessageItem for messages with images - const messageRole = role === 'user' - ? OpenResponsesInputMessageItemRoleUser.User - : role === 'assistant' - ? OpenResponsesInputMessageItemRoleSystem.System // Assistant messages treated as system in this context - : OpenResponsesInputMessageItemRoleSystem.System; + // Process text and image blocks (these become message content) + if (textBlocks.length > 0 || imageBlocks.length > 0) { + const contentItems: (models.ResponseInputText | models.ResponseInputImage)[] = []; + + // Add text blocks + for (const textBlock of textBlocks) { + contentItems.push({ + type: 'input_text', + text: textBlock.text, + }); + } + + // Add image blocks + for (const imageBlock of imageBlocks) { + let imageUrl: string; + + if (imageBlock.source.type === 'url') { + imageUrl = imageBlock.source.url; + } else if (imageBlock.source.type === 'base64') { + imageUrl = `data:${imageBlock.source.media_type};base64,${imageBlock.source.data}`; + } else { + const exhaustiveCheck: never = imageBlock.source; + throw new Error(`Unhandled image source type: ${exhaustiveCheck}`); + } + + contentItems.push({ + type: 'input_image', + detail: 'auto', + imageUrl, + }); + } + // Determine output format based on content + if (imageBlocks.length > 0) { + // Use structured format for messages with images result.push({ type: 'message', - role: messageRole, + role: + role === 'user' + ? OpenResponsesInputMessageItemRoleUser.User + : OpenResponsesInputMessageItemRoleDeveloper.Developer, content: contentItems, }); } else { - // Use simple format for text-only messages + // Use simple string format for text-only messages const textContent = contentItems .filter((item): item is models.ResponseInputText => item.type === 'input_text') - .map(item => item.text) + .map((item) => item.text) .join(''); - if (textContent) { + if (textContent.length > 0) { result.push(createEasyInputMessage(role, textContent)); } } diff --git a/src/lib/chat-compat.ts b/src/lib/chat-compat.ts index 745df5f4..d852b18a 100644 --- a/src/lib/chat-compat.ts +++ b/src/lib/chat-compat.ts @@ -41,6 +41,10 @@ function mapChatRole( return OpenResponsesEasyInputMessageRoleAssistant.Assistant; case "developer": return OpenResponsesEasyInputMessageRoleDeveloper.Developer; + default: { + const exhaustiveCheck: never = role; + throw new Error(`Unhandled role type: ${exhaustiveCheck}`); + } } } diff --git a/src/lib/stream-transformers.ts b/src/lib/stream-transformers.ts index 5a2338b2..cf65d1bb 100644 --- a/src/lib/stream-transformers.ts +++ b/src/lib/stream-transformers.ts @@ -453,7 +453,7 @@ export function responseHasToolCalls(response: models.OpenResponsesNonStreamingR * Convert OpenRouter annotations to Claude citations */ function mapAnnotationsToCitations( - annotations?: Array + annotations?: Array, ): models.ClaudeTextCitation[] | undefined { if (!annotations || annotations.length === 0) { return undefined; @@ -509,7 +509,15 @@ function mapAnnotationsToCitations( default: { const _exhaustiveCheck: never = annotation; - throw new Error(`Unhandled annotation type: ${(_exhaustiveCheck as { type: string }).type}`); + throw new Error( + `Unhandled annotation type: ${ + ( + _exhaustiveCheck as { + type: string; + } + ).type + }`, + ); } } } @@ -561,6 +569,12 @@ export function convertToClaudeMessage( for (const item of response.output) { if (!('type' in item)) { + // Handle items without type field + unsupportedContent.push({ + original_type: 'unknown', + data: item as Record, + reason: 'Output item missing type field', + }); continue; } @@ -569,6 +583,11 @@ export function convertToClaudeMessage( const msgItem = item as models.ResponsesOutputMessage; for (const part of msgItem.content) { if (!('type' in part)) { + unsupportedContent.push({ + original_type: 'unknown_message_part', + data: part as Record, + reason: 'Message content part missing type field', + }); continue; } @@ -579,15 +598,32 @@ export function convertToClaudeMessage( content.push({ type: 'text', text: textPart.text, - ...(citations && { citations }), + ...(citations && { + citations, + }), }); } else if (part.type === 'refusal') { const refusalPart = part as models.OpenAIResponsesRefusalContent; unsupportedContent.push({ original_type: 'refusal', - data: { refusal: refusalPart.refusal }, + data: { + refusal: refusalPart.refusal, + }, reason: 'Claude does not have a native refusal content type', }); + } else { + // Handle unknown message content types + unsupportedContent.push({ + original_type: `message_content_${ + ( + part as { + type: string; + } + ).type + }`, + data: part as Record, + reason: 'Unknown message content type', + }); } } break; @@ -595,12 +631,20 @@ export function convertToClaudeMessage( case 'function_call': { const fnCall = item as models.ResponsesOutputItemFunctionCall; - let parsedInput: Record = {}; + let parsedInput: Record; try { parsedInput = JSON.parse(fnCall.arguments); - } catch { - parsedInput = {}; + } catch (error) { + // Preserve raw arguments if JSON parsing fails + // Log warning in development/debug environments + if (typeof process !== 'undefined' && process.env?.['NODE_ENV'] === 'development') { + // biome-ignore lint/suspicious/noConsole: needed for debugging in development + console.warn(`Failed to parse tool call arguments for ${fnCall.name}:`, error); + } + parsedInput = { + _raw_arguments: fnCall.arguments, + }; } content.push({ @@ -646,7 +690,9 @@ export function convertToClaudeMessage( type: 'server_tool_use', id: webSearchItem.id, name: 'web_search', - input: { status: webSearchItem.status }, + input: { + status: webSearchItem.status, + }, }); break; } @@ -680,8 +726,13 @@ export function convertToClaudeMessage( } default: { - const _exhaustiveCheck: never = item; - throw new Error(`Unhandled output item type: ${(_exhaustiveCheck as { type: string }).type}`); + const exhaustiveCheck: never = item; + unsupportedContent.push({ + original_type: 'unknown_output_item', + data: exhaustiveCheck as Record, + reason: 'Unknown output item type', + }); + break; } } } @@ -700,7 +751,9 @@ export function convertToClaudeMessage( cache_creation_input_tokens: response.usage?.inputTokensDetails?.cachedTokens ?? 0, cache_read_input_tokens: 0, }, - ...(unsupportedContent.length > 0 && { unsupported_content: unsupportedContent }), + ...(unsupportedContent.length > 0 && { + unsupported_content: unsupportedContent, + }), }; } @@ -709,23 +762,19 @@ export function convertToClaudeMessage( */ export function extractUnsupportedContent( message: models.ClaudeMessage, - originalType: string + originalType: string, ): models.UnsupportedContent[] { if (!message.unsupported_content) { return []; } - return message.unsupported_content.filter( - item => item.original_type === originalType - ); + return message.unsupported_content.filter((item) => item.original_type === originalType); } /** * Check if message has any unsupported content */ -export function hasUnsupportedContent( - message: models.ClaudeMessage -): boolean { +export function hasUnsupportedContent(message: models.ClaudeMessage): boolean { return !!(message.unsupported_content && message.unsupported_content.length > 0); } @@ -733,7 +782,7 @@ export function hasUnsupportedContent( * Get summary of unsupported content types */ export function getUnsupportedContentSummary( - message: models.ClaudeMessage + message: models.ClaudeMessage, ): Record { if (!message.unsupported_content) { return {}; diff --git a/src/lib/tool-types.ts b/src/lib/tool-types.ts index a384b7e4..7fc41105 100644 --- a/src/lib/tool-types.ts +++ b/src/lib/tool-types.ts @@ -128,6 +128,58 @@ export type Tool = | ToolWithGenerator, ZodType, ZodType> | ManualTool, ZodType>; +/** + * Extracts the input type from a tool definition + */ +export type InferToolInput = T extends { function: { inputSchema: infer S } } + ? S extends ZodType + ? z.infer + : unknown + : unknown; + +/** + * Extracts the output type from a tool definition + */ +export type InferToolOutput = T extends { function: { outputSchema: infer S } } + ? S extends ZodType + ? z.infer + : unknown + : unknown; + +/** + * A tool call with typed arguments based on the tool's inputSchema + */ +export type TypedToolCall = { + id: string; + name: T extends { function: { name: infer N } } ? N : string; + arguments: InferToolInput; +}; + +/** + * Union of typed tool calls for a tuple of tools + */ +export type TypedToolCallUnion = { + [K in keyof T]: T[K] extends Tool ? TypedToolCall : never; +}[number]; + +/** + * Extracts the event type from a generator tool definition + * Returns `never` for non-generator tools + */ +export type InferToolEvent = T extends { function: { eventSchema: infer S } } + ? S extends ZodType + ? z.infer + : never + : never; + +/** + * Union of event types for all generator tools in a tuple + * Filters out non-generator tools (which return `never`) + */ +export type InferToolEventsUnion = { + [K in keyof T]: T[K] extends Tool ? InferToolEvent : never; +}[number]; + /** * Type guard to check if a tool has an execute function */ @@ -207,34 +259,39 @@ export interface APITool { /** * Tool preliminary result event emitted during generator tool execution + * @template TEvent - The event type from the tool's eventSchema */ -export type ToolPreliminaryResultEvent = { +export type ToolPreliminaryResultEvent = { type: 'tool.preliminary_result'; toolCallId: string; - result: unknown; + result: TEvent; timestamp: number; }; /** * Enhanced stream event types for getFullResponsesStream * Extends OpenResponsesStreamEvent with tool preliminary results + * @template TEvent - The event type from generator tools */ -export type EnhancedResponseStreamEvent = OpenResponsesStreamEvent | ToolPreliminaryResultEvent; +export type EnhancedResponseStreamEvent = + | OpenResponsesStreamEvent + | ToolPreliminaryResultEvent; /** * Type guard to check if an event is a tool preliminary result event */ -export function isToolPreliminaryResultEvent( - event: EnhancedResponseStreamEvent, -): event is ToolPreliminaryResultEvent { +export function isToolPreliminaryResultEvent( + event: EnhancedResponseStreamEvent, +): event is ToolPreliminaryResultEvent { return event.type === 'tool.preliminary_result'; } /** * Tool stream event types for getToolStream * Includes both argument deltas and preliminary results + * @template TEvent - The event type from generator tools */ -export type ToolStreamEvent = +export type ToolStreamEvent = | { type: 'delta'; content: string; @@ -242,14 +299,15 @@ export type ToolStreamEvent = | { type: 'preliminary_result'; toolCallId: string; - result: unknown; + result: TEvent; }; /** * Chat stream event types for getFullChatStream * Includes content deltas, completion events, and tool preliminary results + * @template TEvent - The event type from generator tools */ -export type ChatStreamEvent = +export type ChatStreamEvent = | { type: 'content.delta'; delta: string; @@ -261,7 +319,7 @@ export type ChatStreamEvent = | { type: 'tool.preliminary_result'; toolCallId: string; - result: unknown; + result: TEvent; } | { type: string; diff --git a/src/lib/tool.ts b/src/lib/tool.ts new file mode 100644 index 00000000..23445feb --- /dev/null +++ b/src/lib/tool.ts @@ -0,0 +1,262 @@ +import type { ZodObject, ZodRawShape, ZodType, z } from "zod/v4"; +import { + ToolType, + type TurnContext, + type ToolWithExecute, + type ToolWithGenerator, + type ManualTool, +} from "./tool-types.js"; + +/** + * Configuration for a regular tool with outputSchema + */ +type RegularToolConfigWithOutput< + TInput extends ZodObject, + TOutput extends ZodType, +> = { + name: string; + description?: string; + inputSchema: TInput; + outputSchema: TOutput; + eventSchema?: undefined; + execute: ( + params: z.infer, + context?: TurnContext + ) => Promise> | z.infer; +}; + +/** + * Configuration for a regular tool without outputSchema (infers return type from execute) + */ +type RegularToolConfigWithoutOutput< + TInput extends ZodObject, + TReturn, +> = { + name: string; + description?: string; + inputSchema: TInput; + outputSchema?: undefined; + eventSchema?: undefined; + execute: ( + params: z.infer, + context?: TurnContext + ) => Promise | TReturn; +}; + +/** + * Configuration for a generator tool (with eventSchema) + */ +type GeneratorToolConfig< + TInput extends ZodObject, + TEvent extends ZodType, + TOutput extends ZodType, +> = { + name: string; + description?: string; + inputSchema: TInput; + eventSchema: TEvent; + outputSchema: TOutput; + execute: ( + params: z.infer, + context?: TurnContext + ) => AsyncGenerator | z.infer>; +}; + +/** + * Configuration for a manual tool (execute: false, no eventSchema or outputSchema) + */ +type ManualToolConfig> = { + name: string; + description?: string; + inputSchema: TInput; + execute: false; +}; + +/** + * Union type for all regular tool configs + */ +type RegularToolConfig, TOutput extends ZodType, TReturn> = + | RegularToolConfigWithOutput + | RegularToolConfigWithoutOutput; + +/** + * Type guard to check if config is a generator tool config (has eventSchema) + */ +function isGeneratorConfig< + TInput extends ZodObject, + TEvent extends ZodType, + TOutput extends ZodType, + TReturn, +>( + config: + | GeneratorToolConfig + | RegularToolConfig + | ManualToolConfig +): config is GeneratorToolConfig { + return "eventSchema" in config && config.eventSchema !== undefined; +} + +/** + * Type guard to check if config is a manual tool config (execute === false) + */ +function isManualConfig, TOutput extends ZodType, TReturn>( + config: + | GeneratorToolConfig + | RegularToolConfig + | ManualToolConfig +): config is ManualToolConfig { + return config.execute === false; +} + +/** + * Creates a tool with full type inference from Zod schemas. + * + * The tool type is automatically determined based on the configuration: + * - **Generator tool**: When `eventSchema` is provided + * - **Regular tool**: When `execute` is a function (no `eventSchema`) + * - **Manual tool**: When `execute: false` is set + * + * @example Regular tool: + * ```typescript + * const weatherTool = tool({ + * name: "get_weather", + * description: "Get weather for a location", + * inputSchema: z.object({ location: z.string() }), + * outputSchema: z.object({ temperature: z.number() }), + * execute: async (params) => { + * // params is typed as { location: string } + * return { temperature: 72 }; // return type is enforced + * }, + * }); + * ``` + * + * @example Generator tool (with eventSchema): + * ```typescript + * const progressTool = tool({ + * name: "process_data", + * inputSchema: z.object({ data: z.string() }), + * eventSchema: z.object({ progress: z.number() }), + * outputSchema: z.object({ result: z.string() }), + * execute: async function* (params) { + * yield { progress: 50 }; // typed as event + * yield { result: "done" }; // typed as output + * }, + * }); + * ``` + * + * @example Manual tool (execute: false): + * ```typescript + * const manualTool = tool({ + * name: "external_action", + * inputSchema: z.object({ action: z.string() }), + * execute: false, + * }); + * ``` + */ +// Overload for generator tools (when eventSchema is provided) +export function tool< + TInput extends ZodObject, + TEvent extends ZodType, + TOutput extends ZodType, +>( + config: GeneratorToolConfig +): ToolWithGenerator; + +// Overload for manual tools (execute: false) +export function tool>( + config: ManualToolConfig +): ManualTool; + +// Overload for regular tools with outputSchema +export function tool< + TInput extends ZodObject, + TOutput extends ZodType, +>(config: RegularToolConfigWithOutput): ToolWithExecute; + +// Overload for regular tools without outputSchema (infers return type) +export function tool< + TInput extends ZodObject, + TReturn, +>(config: RegularToolConfigWithoutOutput): ToolWithExecute>; + +// Implementation +export function tool< + TInput extends ZodObject, + TEvent extends ZodType, + TOutput extends ZodType, + TReturn, +>( + config: + | GeneratorToolConfig + | RegularToolConfig + | ManualToolConfig +): + | ToolWithGenerator + | ToolWithExecute + | ToolWithExecute> + | ManualTool { + // Check for manual tool first (execute === false) + if (isManualConfig(config)) { + const fn: ManualTool["function"] = { + name: config.name, + inputSchema: config.inputSchema, + }; + + if (config.description !== undefined) { + fn.description = config.description; + } + + return { + type: ToolType.Function, + function: fn, + }; + } + + // Check for generator tool (has eventSchema) + if (isGeneratorConfig(config)) { + const fn: ToolWithGenerator["function"] = { + name: config.name, + inputSchema: config.inputSchema, + eventSchema: config.eventSchema, + outputSchema: config.outputSchema, + // The config execute allows yielding both events and output, + // but the interface only types for events (output is extracted separately) + execute: config.execute as ToolWithGenerator< + TInput, + TEvent, + TOutput + >["function"]["execute"], + }; + + if (config.description !== undefined) { + fn.description = config.description; + } + + return { + type: ToolType.Function, + function: fn, + }; + } + + // Regular tool (has execute function, no eventSchema) + // Type assertion needed because we have two overloads (with/without outputSchema) + // and the implementation needs to handle both cases + const fn = { + name: config.name, + inputSchema: config.inputSchema, + execute: config.execute, + } as ToolWithExecute["function"]; + + if (config.description !== undefined) { + fn.description = config.description; + } + + if (config.outputSchema !== undefined) { + fn.outputSchema = config.outputSchema; + } + + return { + type: ToolType.Function, + function: fn, + }; +} diff --git a/src/sdk/sdk.ts b/src/sdk/sdk.ts index dd293732..fa0219a8 100644 --- a/src/sdk/sdk.ts +++ b/src/sdk/sdk.ts @@ -20,12 +20,11 @@ import { Providers } from "./providers.js"; // #region imports import { callModel as callModelFunc, + type CallModelInput, } from "../funcs/call-model.js"; import type { ModelResult } from "../lib/model-result.js"; import type { RequestOptions } from "../lib/sdks.js"; -import { type MaxToolRounds, Tool, ToolType } from "../lib/tool-types.js"; -import type { OpenResponsesRequest } from "../models/openresponsesrequest.js"; -import type { OpenResponsesInput } from "../models/openresponsesinput.js"; +import { type MaxToolRounds, ToolType } from "../lib/tool-types.js"; export { ToolType }; export type { MaxToolRounds }; @@ -99,11 +98,7 @@ export class OpenRouter extends ClientSDK { // #region sdk-class-body callModel( - request: Omit & { - input?: OpenResponsesInput; - tools?: Tool[]; - maxToolRounds?: MaxToolRounds; - }, + request: CallModelInput, options?: RequestOptions, ): ModelResult { return callModelFunc(this, request, options); diff --git a/tests/e2e/call-model-tools.test.ts b/tests/e2e/call-model-tools.test.ts index 40f487d6..0fe1dcc2 100644 --- a/tests/e2e/call-model-tools.test.ts +++ b/tests/e2e/call-model-tools.test.ts @@ -64,8 +64,7 @@ describe('Enhanced Tool Support for callModel', () => { target: 'openapi-3.0', }); - // @ts-expect-error - description is not a property of _JSONSchema - expect(jsonSchema.properties?.location?.description).toBe( + expect(jsonSchema.properties?.location?.['description']).toBe( 'City and country e.g. Bogotá, Colombia', ); }); diff --git a/tests/e2e/call-model.test.ts b/tests/e2e/call-model.test.ts index f3349c88..aa726581 100644 --- a/tests/e2e/call-model.test.ts +++ b/tests/e2e/call-model.test.ts @@ -6,7 +6,8 @@ import type { OpenResponsesFunctionCallOutput } from '../../src/models/openrespo import { beforeAll, describe, expect, it } from 'vitest'; import { z } from 'zod/v4'; import { OpenRouter, ToolType } from '../../src/sdk/sdk.js'; -import { toChatMessage } from '../../src/lib/chat-compat.js'; +import { fromChatMessages, toChatMessage } from '../../src/lib/chat-compat.js'; +import { fromClaudeMessages } from '../../src/lib/anthropic-compat.js'; import { OpenResponsesNonStreamingResponse } from '../../src/models/openresponsesnonstreamingresponse.js'; import { OpenResponsesStreamEvent } from '../../src/models/openresponsesstreamevent.js'; @@ -28,7 +29,7 @@ describe('callModel E2E Tests', () => { it('should accept chat-style Message array as input', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'system', content: 'You are a helpful assistant.', @@ -37,7 +38,7 @@ describe('callModel E2E Tests', () => { role: 'user', content: "Say 'chat test' and nothing else.", }, - ], + ]), }); const text = await response.getText(); @@ -50,7 +51,7 @@ describe('callModel E2E Tests', () => { it('should handle multi-turn chat-style conversation', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: 'My favorite color is blue.', @@ -63,7 +64,7 @@ describe('callModel E2E Tests', () => { role: 'user', content: 'What is my favorite color?', }, - ], + ]), }); const text = await response.getText(); @@ -75,7 +76,7 @@ describe('callModel E2E Tests', () => { it('should handle system message in chat-style input', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'system', content: 'Always respond with exactly one word.', @@ -84,7 +85,7 @@ describe('callModel E2E Tests', () => { role: 'user', content: 'Say hello.', }, - ], + ]), }); const text = await response.getText(); @@ -96,12 +97,12 @@ describe('callModel E2E Tests', () => { it('should accept chat-style tools (ToolDefinitionJson)', async () => { const response = client.callModel({ model: 'qwen/qwen3-vl-8b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "What's the weather in Paris? Use the get_weather tool.", }, - ], + ]), tools: [ { type: ToolType.Function, @@ -137,7 +138,7 @@ describe('callModel E2E Tests', () => { it.skip('should work with chat-style messages and chat-style tools together', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.1-8b-instruct', - input: [ + input: fromChatMessages([ { role: 'system', content: 'You are a helpful assistant. Use tools when needed.', @@ -146,7 +147,7 @@ describe('callModel E2E Tests', () => { role: 'user', content: 'Get the weather in Tokyo using the weather tool.', }, - ], + ]), tools: [ { type: ToolType.Function, @@ -189,7 +190,7 @@ describe('callModel E2E Tests', () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: claudeMessages, + input: fromClaudeMessages(claudeMessages), }); const text = await response.getText(); @@ -214,7 +215,7 @@ describe('callModel E2E Tests', () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: claudeMessages, + input: fromClaudeMessages(claudeMessages), }); const text = await response.getText(); @@ -242,7 +243,7 @@ describe('callModel E2E Tests', () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: claudeMessages, + input: fromClaudeMessages(claudeMessages), }); const text = await response.getText(); @@ -270,7 +271,7 @@ describe('callModel E2E Tests', () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: claudeMessages, + input: fromClaudeMessages(claudeMessages), }); const text = await response.getText(); @@ -284,12 +285,12 @@ describe('callModel E2E Tests', () => { it('should successfully get text from a response', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'Hello, World!' and nothing else.", }, - ], + ]), }); const text = await response.getText(); @@ -303,7 +304,7 @@ describe('callModel E2E Tests', () => { it('should handle multi-turn conversations', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: 'My name is Bob.', @@ -316,7 +317,7 @@ describe('callModel E2E Tests', () => { role: 'user', content: 'What is my name?', }, - ], + ]), }); const text = await response.getText(); @@ -330,12 +331,12 @@ describe('callModel E2E Tests', () => { it('should successfully get a complete message from response', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'test message' and nothing else.", }, - ], + ]), }); const fullResponse = await response.getResponse(); @@ -364,12 +365,12 @@ describe('callModel E2E Tests', () => { it('should have proper message structure', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: 'Respond with a simple greeting.', }, - ], + ]), }); const fullResponse = await response.getResponse(); @@ -425,12 +426,12 @@ describe('callModel E2E Tests', () => { it('should successfully stream text deltas', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: 'Count from 1 to 5.', }, - ], + ]), }); const deltas: string[] = []; @@ -450,12 +451,12 @@ describe('callModel E2E Tests', () => { it('should stream progressively without waiting for completion', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: 'Write a short poem.', }, - ], + ]), }); let firstDeltaTime: number | null = null; @@ -485,12 +486,12 @@ describe('callModel E2E Tests', () => { it('should successfully stream incremental message updates in ResponsesOutputMessage format', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'streaming test'.", }, - ], + ]), }); const messages: (ResponsesOutputMessage | OpenResponsesFunctionCallOutput)[] = []; @@ -526,12 +527,12 @@ describe('callModel E2E Tests', () => { it('should return ResponsesOutputMessage with correct shape', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'hello world'.", }, - ], + ]), }); const messages: (ResponsesOutputMessage | OpenResponsesFunctionCallOutput)[] = []; @@ -575,12 +576,12 @@ describe('callModel E2E Tests', () => { it('should include OpenResponsesFunctionCallOutput with correct shape when tools are executed', async () => { const response = client.callModel({ model: 'openai/gpt-4o-mini', - input: [ + input: fromChatMessages([ { role: 'user', content: "What's the weather in Tokyo? Use the get_weather tool.", }, - ], + ]), tools: [ { type: ToolType.Function, @@ -673,12 +674,12 @@ describe('callModel E2E Tests', () => { it('should return messages with all required fields and correct types', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: 'Count from 1 to 3.', }, - ], + ]), }); for await (const message of response.getNewMessagesStream()) { @@ -714,12 +715,12 @@ describe('callModel E2E Tests', () => { it.skip('should successfully stream reasoning deltas for reasoning models', async () => { const response = client.callModel({ model: 'minimax/minimax-m2', - input: [ + input: fromChatMessages([ { role: 'user', content: 'What is 2+2?', }, - ], + ]), reasoning: { enabled: true, effort: 'low', @@ -744,12 +745,12 @@ describe('callModel E2E Tests', () => { it('should successfully stream tool call deltas when tools are called', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.1-8b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "What's the weather like in Paris? Use the get_weather tool to find out.", }, - ], + ]), tools: [ { type: ToolType.Function, @@ -800,12 +801,12 @@ describe('callModel E2E Tests', () => { it('should successfully stream all response events', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'hello'.", }, - ], + ]), }); const events: EnhancedResponseStreamEvent[] = []; @@ -832,12 +833,12 @@ describe('callModel E2E Tests', () => { it('should include text delta events', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: 'Count to 3.', }, - ], + ]), }); const textDeltaEvents: EnhancedResponseStreamEvent[] = []; @@ -865,12 +866,12 @@ describe('callModel E2E Tests', () => { it('should successfully stream in chat-compatible format', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'test'.", }, - ], + ]), }); const chunks: ChatStreamEvent[] = []; @@ -891,12 +892,12 @@ describe('callModel E2E Tests', () => { it('should return events with correct shape for each event type', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: 'Count from 1 to 3.', }, - ], + ]), }); let hasContentDelta = false; @@ -952,12 +953,12 @@ describe('callModel E2E Tests', () => { it('should validate content.delta events have proper structure', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'hello world'.", }, - ], + ]), }); const contentDeltas: ChatStreamEvent[] = []; @@ -989,12 +990,12 @@ describe('callModel E2E Tests', () => { it('should include tool.preliminary_result events with correct shape when generator tools are executed', async () => { const response = client.callModel({ model: 'openai/gpt-4o-mini', - input: [ + input: fromChatMessages([ { role: 'user', content: 'What time is it? Use the get_time tool.', }, - ], + ]), tools: [ { type: ToolType.Function, @@ -1078,12 +1079,12 @@ describe('callModel E2E Tests', () => { it('should allow reading text and streaming simultaneously', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'concurrent test'.", }, - ], + ]), }); // Get full text and stream concurrently @@ -1113,12 +1114,12 @@ describe('callModel E2E Tests', () => { it('should allow multiple stream consumers', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: 'Write a short sentence.', }, - ], + ]), }); // Start two concurrent stream consumers @@ -1161,12 +1162,12 @@ describe('callModel E2E Tests', () => { it('should allow sequential consumption - text then stream', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'sequential test'.", }, - ], + ]), }); // First, get the full text @@ -1191,12 +1192,12 @@ describe('callModel E2E Tests', () => { it('should allow sequential consumption - stream then text', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'reverse test'.", }, - ], + ]), }); // First, collect deltas from stream @@ -1222,12 +1223,12 @@ describe('callModel E2E Tests', () => { it('should handle invalid model gracefully', async () => { const response = client.callModel({ model: 'invalid/model-that-does-not-exist', - input: [ + input: fromChatMessages([ { role: 'user', content: 'Test', }, - ], + ]), }); await expect(response.getText()).rejects.toThrow(); @@ -1253,12 +1254,12 @@ describe('callModel E2E Tests', () => { it('should return full response with correct shape', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'hello'.", }, - ], + ]), }); const fullResponse = await response.getResponse(); @@ -1299,12 +1300,12 @@ describe('callModel E2E Tests', () => { it('should return usage with correct shape including all token details', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'hello'.", }, - ], + ]), }); const fullResponse = await response.getResponse(); @@ -1355,12 +1356,12 @@ describe('callModel E2E Tests', () => { it('should return error and incompleteDetails fields with correct shape', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'test'.", }, - ], + ]), }); const fullResponse = await response.getResponse(); @@ -1381,12 +1382,12 @@ describe('callModel E2E Tests', () => { it('should allow concurrent access with other methods', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'test'.", }, - ], + ]), }); // Get both text and full response concurrently @@ -1409,12 +1410,12 @@ describe('callModel E2E Tests', () => { it('should return consistent results on multiple calls', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'consistent'.", }, - ], + ]), }); const firstCall = await response.getResponse(); @@ -1434,12 +1435,12 @@ describe('callModel E2E Tests', () => { it('should respect maxOutputTokens parameter', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: 'Write a long story about a cat.', }, - ], + ]), maxOutputTokens: 10, }); @@ -1453,12 +1454,12 @@ describe('callModel E2E Tests', () => { it('should work with instructions parameter', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say exactly: 'test complete'", }, - ], + ]), instructions: 'You are a helpful assistant. Keep responses concise.', }); @@ -1473,12 +1474,12 @@ describe('callModel E2E Tests', () => { it('should support provider parameter with correct shape', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'provider test'.", }, - ], + ]), provider: { allowFallbacks: true, requireParameters: false, @@ -1495,12 +1496,12 @@ describe('callModel E2E Tests', () => { it('should support provider with order preference', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'ordered provider'.", }, - ], + ]), provider: { order: [ 'Together', @@ -1520,12 +1521,12 @@ describe('callModel E2E Tests', () => { it('should support provider with ignore list', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'ignore test'.", }, - ], + ]), provider: { ignore: [ 'SomeProvider', @@ -1543,12 +1544,12 @@ describe('callModel E2E Tests', () => { it('should support provider with quantizations filter', async () => { const response = client.callModel({ model: 'meta-llama/llama-3.2-1b-instruct', - input: [ + input: fromChatMessages([ { role: 'user', content: "Say 'quantization test'.", }, - ], + ]), provider: { allowFallbacks: true, }, diff --git a/tests/unit/create-tool.test.ts b/tests/unit/create-tool.test.ts new file mode 100644 index 00000000..bb26c0cd --- /dev/null +++ b/tests/unit/create-tool.test.ts @@ -0,0 +1,185 @@ +import { describe, expect, it } from 'vitest'; +import { z } from 'zod/v4'; +import { tool } from '../../src/lib/tool.js'; +import { ToolType } from '../../src/lib/tool-types.js'; + +describe('tool', () => { + describe('tool - regular tools', () => { + it('should create a tool with the correct structure', () => { + const testTool = tool({ + name: 'test_tool', + description: 'A test tool', + inputSchema: z.object({ + input: z.string(), + }), + execute: async (params) => { + return { result: params.input }; + }, + }); + + expect(testTool.type).toBe(ToolType.Function); + expect(testTool.function.name).toBe('test_tool'); + expect(testTool.function.description).toBe('A test tool'); + expect(testTool.function.inputSchema).toBeDefined(); + }); + + it('should infer execute params from inputSchema', async () => { + const weatherTool = tool({ + name: 'weather', + inputSchema: z.object({ + location: z.string(), + units: z.enum(['celsius', 'fahrenheit']).optional(), + }), + execute: async (params) => { + // params should be typed as { location: string; units?: 'celsius' | 'fahrenheit' } + const location: string = params.location; + const units: 'celsius' | 'fahrenheit' | undefined = params.units; + return { location, units }; + }, + }); + + const result = await weatherTool.function.execute({ location: 'NYC', units: 'fahrenheit' }); + expect(result.location).toBe('NYC'); + expect(result.units).toBe('fahrenheit'); + }); + + it('should enforce output schema return type', async () => { + const tempTool = tool({ + name: 'get_temperature', + inputSchema: z.object({ + location: z.string(), + }), + outputSchema: z.object({ + temperature: z.number(), + description: z.string(), + }), + execute: async (_params) => { + // Return type should be enforced as { temperature: number; description: string } + return { + temperature: 72, + description: 'Sunny', + }; + }, + }); + + const result = await tempTool.function.execute({ location: 'NYC' }); + expect(result.temperature).toBe(72); + expect(result.description).toBe('Sunny'); + }); + + it('should support synchronous execute functions', () => { + const syncTool = tool({ + name: 'sync_tool', + inputSchema: z.object({ + a: z.number(), + b: z.number(), + }), + execute: (params) => { + return { sum: params.a + params.b }; + }, + }); + + const result = syncTool.function.execute({ a: 5, b: 3 }); + expect(result).toEqual({ sum: 8 }); + }); + + it('should pass context to execute function', async () => { + let receivedContext: unknown; + + const contextTool = tool({ + name: 'context_tool', + inputSchema: z.object({}), + execute: async (_params, context) => { + receivedContext = context; + return {}; + }, + }); + + const mockContext = { + numberOfTurns: 3, + messageHistory: [], + model: 'test-model', + }; + + await contextTool.function.execute({}, mockContext); + expect(receivedContext).toEqual(mockContext); + }); + }); + + describe('tool - generator tools (with eventSchema)', () => { + it('should create a generator tool with the correct structure', () => { + const streamingTool = tool({ + name: 'streaming_tool', + description: 'A streaming tool', + inputSchema: z.object({ + query: z.string(), + }), + eventSchema: z.object({ + progress: z.number(), + }), + outputSchema: z.object({ + result: z.string(), + }), + execute: async function* (_params) { + yield { progress: 50 }; + yield { result: 'done' }; + }, + }); + + expect(streamingTool.type).toBe(ToolType.Function); + expect(streamingTool.function.name).toBe('streaming_tool'); + expect(streamingTool.function.eventSchema).toBeDefined(); + expect(streamingTool.function.outputSchema).toBeDefined(); + }); + + it('should yield properly typed events and output', async () => { + const progressTool = tool({ + name: 'progress_tool', + inputSchema: z.object({ + data: z.string(), + }), + eventSchema: z.object({ + status: z.string(), + progress: z.number(), + }), + outputSchema: z.object({ + completed: z.boolean(), + result: z.string(), + }), + execute: async function* (params) { + yield { status: 'started', progress: 0 }; + yield { status: 'processing', progress: 50 }; + yield { completed: true, result: `Processed: ${params.data}` }; + }, + }); + + const results: unknown[] = []; + const mockContext = { numberOfTurns: 1, messageHistory: [] }; + for await (const event of progressTool.function.execute({ data: 'test' }, mockContext)) { + results.push(event); + } + + expect(results).toHaveLength(3); + expect(results[0]).toEqual({ status: 'started', progress: 0 }); + expect(results[1]).toEqual({ status: 'processing', progress: 50 }); + expect(results[2]).toEqual({ completed: true, result: 'Processed: test' }); + }); + }); + + describe('tool - manual tools (execute: false)', () => { + it('should create a manual tool without execute function', () => { + const manualTool = tool({ + name: 'manual_tool', + description: 'A manual tool', + inputSchema: z.object({ + query: z.string(), + }), + execute: false, + }); + + expect(manualTool.type).toBe(ToolType.Function); + expect(manualTool.function.name).toBe('manual_tool'); + expect(manualTool.function).not.toHaveProperty('execute'); + }); + }); +});