Skip to content

Commit cd7a21c

Browse files
committed
feat: Implement function calling
- Added a function call converter with methods to detect function calls, extract text, and convert responses to XML.
- Created unit tests for the function call converter to ensure proper detection and conversion of various tool call formats.
- Introduced a tool registry to manage supported tools and their schemas, including generation of OpenAI and Anthropic schemas.
- Defined schemas for new tools: apply_diff, insert_content, list_code_definition_names, list_files, read_file, search_and_replace, search_files, and write_to_file.
- Enhanced the API to support tool calling based on provider type.
- Updated the UI to include settings for enabling function calling, with appropriate translations.
1 parent 142cdb5 commit cd7a21c

31 files changed

+1392
-7
lines changed

packages/types/src/provider-settings.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ const baseProviderSettingsSchema = z.object({
6969
includeMaxTokens: z.boolean().optional(),
7070
diffEnabled: z.boolean().optional(),
7171
todoListEnabled: z.boolean().optional(),
72+
toolCallEnabled: z.boolean().optional(),
7273
fuzzyMatchThreshold: z.number().optional(),
7374
modelTemperature: z.number().nullish(),
7475
rateLimitSeconds: z.number().optional(),

src/api/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import {
3636
ZAiHandler,
3737
FireworksHandler,
3838
} from "./providers"
39+
import { ToolArgs } from "../core/prompts/tools/types"
3940

4041
export interface SingleCompletionHandler {
4142
completePrompt(prompt: string): Promise<string>
@@ -44,6 +45,8 @@ export interface SingleCompletionHandler {
4445
export interface ApiHandlerCreateMessageMetadata { // Per-request metadata passed as the optional third argument of a handler's createMessage().
4546
mode?: string
4647
taskId: string
48+
tools?: string[] // Tool names to expose through native function calling; the OpenAI and OpenRouter handlers treat a non-empty list as "tool calling enabled".
49+
toolArgs?: ToolArgs // Forwarded by the OpenRouter handler to generateFunctionCallSchemas() together with `tools`.
4750
}
4851

4952
export interface ApiHandler {

src/api/providers/openai.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import { getModelParams } from "../transform/model-params"
2323
import { DEFAULT_HEADERS } from "./constants"
2424
import { BaseProvider } from "./base-provider"
2525
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
26+
import { getToolRegistry } from "../../core/tools/schemas/tool-registry"
2627

2728
// TODO: Rename this to OpenAICompatibleHandler. Also, I think the
2829
// `OpenAINativeHandler` can subclass from this, since it's obviously
@@ -86,6 +87,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
8687
const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
8788
const ark = modelUrl.includes(".volces.com")
8889

90+
const toolCallEnabled = metadata?.tools && metadata.tools.length > 0
91+
const toolRegistry = getToolRegistry()
92+
8993
if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
9094
yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
9195
return
@@ -157,6 +161,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
157161
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
158162
...(reasoning && reasoning),
159163
}
164+
if (toolCallEnabled) {
165+
requestOptions.tools = toolRegistry.generateFunctionCallSchemas(metadata.tools!)
166+
}
160167

161168
// Add max_tokens if needed
162169
this.addMaxTokensIfNeeded(requestOptions, modelInfo)
@@ -192,6 +199,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
192199
text: (delta.reasoning_content as string | undefined) || "",
193200
}
194201
}
202+
if (delta?.tool_calls) {
203+
yield { type: "tool_call", toolCalls: delta.tool_calls, toolCallType: "openai" }
204+
}
195205
if (chunk.usage) {
196206
lastUsage = chunk.usage
197207
}

src/api/providers/openrouter.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ import { getModelEndpoints } from "./fetchers/modelEndpointCache"
2424

2525
import { DEFAULT_HEADERS } from "./constants"
2626
import { BaseProvider } from "./base-provider"
27-
import type { SingleCompletionHandler } from "../index"
27+
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
28+
import { getToolRegistry } from "../../core/tools/schemas/tool-registry"
2829

2930
// Add custom interface for OpenRouter params.
3031
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
@@ -72,10 +73,13 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
7273
override async *createMessage(
7374
systemPrompt: string,
7475
messages: Anthropic.Messages.MessageParam[],
76+
metadata?: ApiHandlerCreateMessageMetadata,
7577
): AsyncGenerator<ApiStreamChunk> {
7678
const model = await this.fetchModel()
7779

7880
let { id: modelId, maxTokens, temperature, topP, reasoning } = model
81+
const toolCallEnabled = metadata?.tools && metadata.tools.length > 0
82+
const toolRegistry = getToolRegistry()
7983

8084
// OpenRouter sends reasoning tokens by default for Gemini 2.5 Pro
8185
// Preview even if you don't request them. This is not the default for
@@ -133,6 +137,9 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
133137
...(transforms && { transforms }),
134138
...(reasoning && { reasoning }),
135139
}
140+
if (toolCallEnabled) {
141+
completionParams.tools = toolRegistry.generateFunctionCallSchemas(metadata.tools!, metadata.toolArgs!)
142+
}
136143

137144
const stream = await this.client.chat.completions.create(completionParams)
138145

@@ -156,6 +163,11 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
156163
yield { type: "text", text: delta.content }
157164
}
158165

166+
// Handle tool calls
167+
if (delta?.tool_calls) {
168+
yield { type: "tool_call", toolCalls: delta.tool_calls, toolCallType: "openai" }
169+
}
170+
159171
if (chunk.usage) {
160172
lastUsage = chunk.usage
161173
}

src/api/transform/stream.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
import { ToolCallProviderType } from "../../shared/api"
2+
13
export type ApiStream = AsyncGenerator<ApiStreamChunk>
24

3-
export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk | ApiStreamReasoningChunk | ApiStreamError
5+
export type ApiStreamChunk = // Union of every chunk kind an ApiStream (AsyncGenerator<ApiStreamChunk>) can yield.
6+
| ApiStreamTextChunk
7+
| ApiStreamUsageChunk
8+
| ApiStreamReasoningChunk
9+
| ApiStreamError
10+
| ApiStreamToolCallChunk // Member added by this commit; discriminated by type: "tool_call".
411

512
export interface ApiStreamError {
613
type: "error"
@@ -27,3 +34,9 @@ export interface ApiStreamUsageChunk {
2734
reasoningTokens?: number
2835
totalCost?: number
2936
}
37+
38+
export interface ApiStreamToolCallChunk { // Stream chunk carrying tool-call data emitted by a provider.
39+
type: "tool_call" // Discriminant used when switching over ApiStreamChunk.
40+
toolCalls: any // The OpenAI and OpenRouter handlers assign the stream's raw `delta.tool_calls` here. NOTE(review): typed `any`; consider a narrower type.
41+
toolCallType: ToolCallProviderType // Identifies the provider format; both handlers in this commit emit "openai".
42+
}

src/core/assistant-message/presentAssistantMessage.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ export async function presentAssistantMessage(cline: Task) {
429429
)
430430
}
431431

432-
if (isMultiFileApplyDiffEnabled) {
432+
if (isMultiFileApplyDiffEnabled || cline.apiConfiguration.toolCallEnabled === true) {
433433
await checkpointSaveAndMark(cline)
434434
await applyDiffTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
435435
} else {

src/core/config/ProviderSettingsManager.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ export const providerProfilesSchema = z.object({
3232
openAiHeadersMigrated: z.boolean().optional(),
3333
consecutiveMistakeLimitMigrated: z.boolean().optional(),
3434
todoListEnabledMigrated: z.boolean().optional(),
35+
toolcallEnabledMigrated: z.boolean().optional(),
3536
})
3637
.optional(),
3738
})
@@ -156,6 +157,11 @@ export class ProviderSettingsManager {
156157
providerProfiles.migrations.todoListEnabledMigrated = true
157158
isDirty = true
158159
}
160+
if (!providerProfiles.migrations.toolcallEnabledMigrated) {
161+
await this.migrateToolCallEnabled(providerProfiles)
162+
providerProfiles.migrations.toolcallEnabledMigrated = true
163+
isDirty = true
164+
}
159165

160166
if (isDirty) {
161167
await this.store(providerProfiles)
@@ -273,6 +279,17 @@ export class ProviderSettingsManager {
273279
console.error(`[MigrateTodoListEnabled] Failed to migrate todo list enabled setting:`, error)
274280
}
275281
}
282+
private async migrateToolCallEnabled(providerProfiles: ProviderProfiles) { // One-time migration: fills in `toolCallEnabled` on every stored API config that lacks it.
283+
try {
284+
for (const [_name, apiConfig] of Object.entries(providerProfiles.apiConfigs)) {
285+
if (apiConfig.toolCallEnabled === undefined) { // Only touch configs where the setting was never written; explicit true/false is preserved.
286+
apiConfig.toolCallEnabled = true // NOTE(review): migration defaults to true, while Task.ts falls back to false via `apiConfiguration?.toolCallEnabled ?? false` - confirm which default is intended.
287+
}
288+
}
289+
} catch (error) { // Failures are logged and swallowed, so the caller still sets the migrated flag afterwards.
290+
console.error(`[migrateToolCallEnabled] Failed to migrate tool call enabled setting:`, error)
291+
}
292+
}
276293

277294
/**
278295
* List all available configs with metadata.

src/core/prompts/tools/index.ts

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import { getNewTaskDescription } from "./new-task"
2424
import { getCodebaseSearchDescription } from "./codebase-search"
2525
import { getUpdateTodoListDescription } from "./update-todo-list"
2626
import { CodeIndexManager } from "../../../services/code-index/manager"
27+
import { getToolRegistry } from "../../tools/schemas/tool-registry"
2728

2829
// Map of tool names to their description functions
2930
const toolDescriptionMap: Record<string, (args: ToolArgs) => string | undefined> = {
@@ -118,7 +119,21 @@ export function getToolDescriptionsForMode(
118119
tools.delete("update_todo_list")
119120
}
120121

121-
// Map tool descriptions for allowed tools
122+
// If toolCallEnabled is true, skip XML tool descriptions for supported tools
123+
let supportedTools = []
124+
if (settings?.toolCallEnabled === true) {
125+
const toolRegistry = getToolRegistry()
126+
supportedTools = toolRegistry.getSupportedTools(Array.from(tools))
127+
128+
for (const tool of supportedTools) {
129+
// tools.delete(tool)
130+
toolDescriptionMap[tool] = (args) => {
131+
return `## ${tool}\n\nMUST USE ${tool} TOOL BY LLM NATIVE TOOL CALL. NOT USING XML FORMAT.`
132+
}
133+
}
134+
}
135+
136+
// Map tool descriptions for allowed tools (traditional XML mode)
122137
const descriptions = Array.from(tools).map((toolName) => {
123138
const descriptionFn = toolDescriptionMap[toolName]
124139
if (!descriptionFn) {

src/core/prompts/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
export interface SystemPromptSettings { // Settings object passed into system-prompt generation.
55
maxConcurrentFileReads: number
66
todoListEnabled: boolean
7+
toolCallEnabled?: boolean // When strictly true, getToolDescriptionsForMode() swaps the XML description of registry-supported tools for a "use native tool call" instruction.
78
useAgentRules: boolean
89
}

src/core/task/Task.ts

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ import { CloudService } from "@roo-code/cloud"
3636

3737
// api
3838
import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api"
39+
import { getToolRegistry } from "../tools/schemas/tool-registry"
40+
import { getGroupName, getToolsForMode } from "../../shared/modes"
3941
import { ApiStream } from "../../api/transform/stream"
4042

4143
// shared
@@ -46,10 +48,10 @@ import { t } from "../../i18n"
4648
import { ClineApiReqCancelReason, ClineApiReqInfo } from "../../shared/ExtensionMessage"
4749
import { getApiMetrics } from "../../shared/getApiMetrics"
4850
import { ClineAskResponse } from "../../shared/WebviewMessage"
49-
import { defaultModeSlug } from "../../shared/modes"
50-
import { DiffStrategy } from "../../shared/tools"
51+
import { defaultModeSlug, modes, getModeBySlug } from "../../shared/modes"
52+
import { DiffStrategy, ToolUse } from "../../shared/tools"
5153
import { EXPERIMENT_IDS, experiments } from "../../shared/experiments"
52-
import { getModelMaxOutputTokens } from "../../shared/api"
54+
import { getModelMaxOutputTokens, supportToolCall } from "../../shared/api"
5355

5456
// services
5557
import { UrlContentFetcher } from "../../services/browser/UrlContentFetcher"
@@ -99,6 +101,8 @@ import { getMessagesSinceLastSummary, summarizeConversation } from "../condense"
99101
import { maybeRemoveImageBlocks } from "../../api/transform/image-cleaning"
100102
import { restoreTodoListForTask } from "../tools/updateTodoListTool"
101103
import { AutoApprovalHandler } from "./AutoApprovalHandler"
104+
import { handleOpenaiToolCall } from "./tool-call-helper"
105+
import { ToolArgs } from "../prompts/tools/types"
102106

103107
const MAX_EXPONENTIAL_BACKOFF_SECONDS = 600 // 10 minutes
104108

@@ -1568,6 +1572,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
15681572
const stream = this.attemptApiRequest()
15691573
let assistantMessage = ""
15701574
let reasoningMessage = ""
1575+
let accumulatedToolCalls: any[] = []
15711576
this.isStreaming = true
15721577

15731578
try {
@@ -1612,6 +1617,37 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
16121617
presentAssistantMessage(this)
16131618
break
16141619
}
1620+
case "tool_call": {
1621+
if (chunk.toolCallType === "openai") {
1622+
const xmlContent = handleOpenaiToolCall(accumulatedToolCalls, chunk)
1623+
if (xmlContent) {
1624+
console.log(
1625+
`[RooCode#recursivelyMakeRooRequests] OpenAI tool call XML content:\n${xmlContent}`,
1626+
)
1627+
assistantMessage += xmlContent
1628+
1629+
// Parse raw assistant message chunk into content blocks.
1630+
const prevLength = this.assistantMessageContent.length
1631+
if (this.isAssistantMessageParserEnabled && this.assistantMessageParser) {
1632+
this.assistantMessageContent =
1633+
this.assistantMessageParser.processChunk(xmlContent)
1634+
} else {
1635+
// Use the old parsing method when experiment is disabled
1636+
this.assistantMessageContent = parseAssistantMessage(assistantMessage)
1637+
}
1638+
1639+
if (this.assistantMessageContent.length > prevLength) {
1640+
// New content we need to present, reset to
1641+
// false in case previous content set this to true.
1642+
this.userMessageContentReady = false
1643+
}
1644+
1645+
// Present content to user.
1646+
presentAssistantMessage(this)
1647+
break
1648+
}
1649+
}
1650+
}
16151651
}
16161652

16171653
if (this.abort) {
@@ -1884,6 +1920,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
18841920
{
18851921
maxConcurrentFileReads: maxConcurrentFileReads ?? 5,
18861922
todoListEnabled: apiConfiguration?.todoListEnabled ?? true,
1923+
toolCallEnabled: apiConfiguration?.toolCallEnabled ?? false,
18871924
useAgentRules: vscode.workspace.getConfiguration("roo-cline").get<boolean>("useAgentRules") ?? true,
18881925
},
18891926
)
@@ -1902,6 +1939,12 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
19021939
autoCondenseContext = true,
19031940
autoCondenseContextPercent = 100,
19041941
profileThresholds = {},
1942+
browserViewportSize,
1943+
experiments,
1944+
enableMcpServerCreation,
1945+
maxConcurrentFileReads,
1946+
maxReadFileLine,
1947+
browserToolEnabled,
19051948
} = state ?? {}
19061949

19071950
// Get condensing configuration for automatic triggers.
@@ -2024,9 +2067,69 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
20242067
throw new Error("Auto-approval limit reached and user did not approve continuation")
20252068
}
20262069

2070+
// Generate tool schemas if toolCallEnabled is true
2071+
let tools: string[] | undefined = undefined
2072+
let toolArgs: ToolArgs | undefined
2073+
const apiProvider = this.apiConfiguration.apiProvider
2074+
if (this.apiConfiguration.toolCallEnabled === true && supportToolCall(apiProvider)) {
2075+
const toolRegistry = getToolRegistry()
2076+
const provider = this.providerRef.deref()
2077+
2078+
if (provider) {
2079+
const state = await provider.getState()
2080+
const modeConfig =
2081+
getModeBySlug(mode!, state.customModes) || modes.find((m) => m.slug === mode) || modes[0]
2082+
const availableTools = getToolsForMode(modeConfig.groups)
2083+
const supportedTools = toolRegistry.getSupportedTools(availableTools)
2084+
tools = supportedTools
2085+
2086+
// Determine if browser tools can be used based on model support, mode, and user settings
2087+
let modelSupportsComputerUse = false
2088+
2089+
// Create a temporary API handler to check if the model supports computer use
2090+
// This avoids relying on an active Cline instance which might not exist during preview
2091+
try {
2092+
const tempApiHandler = buildApiHandler(apiConfiguration!)
2093+
modelSupportsComputerUse = tempApiHandler.getModel().info.supportsComputerUse ?? false
2094+
} catch (error) {
2095+
console.error("Error checking if model supports computer use:", error)
2096+
}
2097+
2098+
const modeSupportsBrowser =
2099+
modeConfig?.groups.some((group) => getGroupName(group) === "browser") ?? false
2100+
2101+
// Only enable browser tools if the model supports it, the mode includes browser tools,
2102+
// and browser tools are enabled in settings
2103+
const canUseBrowserTool =
2104+
modelSupportsComputerUse && modeSupportsBrowser && (browserToolEnabled ?? true)
2105+
2106+
toolArgs = {
2107+
cwd: this.cwd,
2108+
supportsComputerUse: canUseBrowserTool,
2109+
diffStrategy: this.diffStrategy,
2110+
browserViewportSize,
2111+
mcpHub: provider.getMcpHub(),
2112+
partialReadsEnabled: maxReadFileLine !== -1,
2113+
settings: {
2114+
...{
2115+
maxConcurrentFileReads: maxConcurrentFileReads ?? 5,
2116+
todoListEnabled: apiConfiguration?.todoListEnabled ?? true,
2117+
toolCallEnabled: apiConfiguration?.toolCallEnabled ?? false,
2118+
useAgentRules:
2119+
vscode.workspace.getConfiguration("roo-cline").get<boolean>("useAgentRules") ?? true,
2120+
},
2121+
enableMcpServerCreation,
2122+
},
2123+
experiments,
2124+
}
2125+
}
2126+
}
2127+
20272128
const metadata: ApiHandlerCreateMessageMetadata = {
20282129
mode: mode,
20292130
taskId: this.taskId,
2131+
tools: tools,
2132+
toolArgs: toolArgs,
20302133
}
20312134

20322135
const stream = this.api.createMessage(systemPrompt, cleanConversationHistory, metadata)

0 commit comments

Comments
 (0)