Skip to content

Commit 6e667fc

Browse files
committed
feat: support tool calls
- Added StreamingToolCallProcessor to handle tool call data and convert JSON chunks to XML format. - Introduced ToolCallProcessingState to manage the parsing state and track progress. - Enhanced applyDiffTool to support new tool call functionality based on provider state. - Updated generateSystemPrompt to include tool call configuration. - Added supportToolCall utility to determine if a provider supports tool calls. - Created ToolCallSettingsControl component for user interface settings. - Localized new tool call settings in multiple languages.
1 parent 185365a commit 6e667fc

File tree

67 files changed

+3017
-127
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

67 files changed

+3017
-127
lines changed

packages/types/src/provider-settings.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ const baseProviderSettingsSchema = z.object({
7878
includeMaxTokens: z.boolean().optional(),
7979
diffEnabled: z.boolean().optional(),
8080
todoListEnabled: z.boolean().optional(),
81+
toolCallEnabled: z.boolean().optional(),
8182
fuzzyMatchThreshold: z.number().optional(),
8283
modelTemperature: z.number().nullish(),
8384
rateLimitSeconds: z.number().optional(),

src/api/index.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { Anthropic } from "@anthropic-ai/sdk"
22

3-
import type { ProviderSettings, ModelInfo } from "@roo-code/types"
3+
import type { ProviderSettings, ModelInfo, ToolName } from "@roo-code/types"
44

55
import { ApiStream } from "./transform/stream"
66

@@ -37,6 +37,7 @@ import {
3737
FireworksHandler,
3838
} from "./providers"
3939
import { NativeOllamaHandler } from "./providers/native-ollama"
40+
import { ToolArgs } from "../core/prompts/tools/types"
4041

4142
export interface SingleCompletionHandler {
4243
completePrompt(prompt: string): Promise<string>
@@ -52,6 +53,14 @@ export interface ApiHandlerCreateMessageMetadata {
5253
* Used to enforce "skip once" after a condense operation.
5354
*/
5455
suppressPreviousResponseId?: boolean
56+
/**
57+
* tool call
58+
*/
59+
tools?: ToolName[]
60+
/**
61+
* tool call args
62+
*/
63+
toolArgs?: ToolArgs
5564
}
5665

5766
export interface ApiHandler {

src/api/providers/base-provider.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,34 @@ export abstract class BaseProvider implements ApiHandler {
3232

3333
return countTokens(content, { useWorker: true })
3434
}
35+
36+
/**
37+
* Convert tool schemas to text format for token counting
38+
*/
39+
protected convertToolSchemasToText(toolSchemas: Anthropic.ToolUnion[]): string {
40+
if (toolSchemas.length === 0) {
41+
return ""
42+
}
43+
44+
const toolsDescription = toolSchemas
45+
.map((tool) => {
46+
// Handle different tool types by accessing properties safely
47+
const toolName = tool.name
48+
let toolText = `Tool: ${toolName}\n`
49+
50+
// Try to access description and input_schema properties
51+
if ("description" in tool) {
52+
toolText += `Description: ${tool.description}\n`
53+
}
54+
55+
if ("input_schema" in tool && tool.input_schema && typeof tool.input_schema === "object") {
56+
toolText += `Parameters:\n${JSON.stringify(tool.input_schema, null, 2)}\n`
57+
}
58+
59+
return toolText
60+
})
61+
.join("\n---\n")
62+
63+
return `Available Tools:\n${toolsDescription}`
64+
}
3565
}

src/api/providers/lm-studio.ts

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import { BaseProvider } from "./base-provider"
1515
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
1616
import { getModels, getModelsFromCache } from "./fetchers/modelCache"
1717
import { getApiRequestTimeout } from "./utils/timeout-config"
18+
import { getToolRegistry } from "../../core/prompts/tools/schemas/tool-registry"
1819

1920
export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
2021
protected options: ApiHandlerOptions
@@ -40,6 +41,8 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
4041
{ role: "system", content: systemPrompt },
4142
...convertToOpenAiMessages(messages),
4243
]
44+
const toolCallEnabled = metadata?.tools && metadata.tools.length > 0
45+
const toolRegistry = getToolRegistry()
4346

4447
// -------------------------
4548
// Track token usage
@@ -68,7 +71,17 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
6871

6972
let inputTokens = 0
7073
try {
71-
inputTokens = await this.countTokens([{ type: "text", text: systemPrompt }, ...toContentBlocks(messages)])
74+
const inputMessages: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: systemPrompt }]
75+
if (toolCallEnabled) {
76+
const toolSchemas: Anthropic.ToolUnion[] = toolRegistry.generateAnthropicToolSchemas(
77+
metadata.tools!,
78+
metadata.toolArgs,
79+
)
80+
const toolsText = this.convertToolSchemasToText(toolSchemas)
81+
inputMessages.push({ type: "text", text: toolsText })
82+
}
83+
inputMessages.push(...toContentBlocks(messages))
84+
inputTokens = await this.countTokens(inputMessages)
7285
} catch (err) {
7386
console.error("[LmStudio] Failed to count input tokens:", err)
7487
inputTokens = 0
@@ -83,6 +96,10 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
8396
temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
8497
stream: true,
8598
}
99+
if (toolCallEnabled) {
100+
params.tools = toolRegistry.generateFunctionCallSchemas(metadata.tools!, metadata.toolArgs)
101+
params.tool_choice = "auto"
102+
}
86103

87104
if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
88105
params.draft_model = this.options.lmStudioDraftModelId
@@ -108,6 +125,9 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
108125
yield processedChunk
109126
}
110127
}
128+
if (delta?.tool_calls) {
129+
yield { type: "tool_call", toolCalls: delta.tool_calls, toolCallType: "openai" }
130+
}
111131
}
112132

113133
for (const processedChunk of matcher.final()) {

src/api/providers/openai.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import { DEFAULT_HEADERS } from "./constants"
2424
import { BaseProvider } from "./base-provider"
2525
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
2626
import { getApiRequestTimeout } from "./utils/timeout-config"
27+
import { getToolRegistry } from "../../core/prompts/tools/schemas/tool-registry"
2728

2829
// TODO: Rename this to OpenAICompatibleHandler. Also, I think the
2930
// `OpenAINativeHandler` can subclass from this, since it's obviously
@@ -92,6 +93,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
9293
const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
9394
const ark = modelUrl.includes(".volces.com")
9495

96+
const toolCallEnabled = metadata?.tools && metadata.tools.length > 0
97+
const toolRegistry = getToolRegistry()
98+
9599
if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
96100
yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
97101
return
@@ -163,6 +167,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
163167
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
164168
...(reasoning && reasoning),
165169
}
170+
if (toolCallEnabled) {
171+
requestOptions.tools = toolRegistry.generateFunctionCallSchemas(metadata.tools!, metadata.toolArgs)
172+
requestOptions.tool_choice = "auto"
173+
}
166174

167175
// Add max_tokens if needed
168176
this.addMaxTokensIfNeeded(requestOptions, modelInfo)
@@ -198,6 +206,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
198206
text: (delta.reasoning_content as string | undefined) || "",
199207
}
200208
}
209+
if (delta?.tool_calls) {
210+
yield { type: "tool_call", toolCalls: delta.tool_calls, toolCallType: "openai" }
211+
}
201212
if (chunk.usage) {
202213
lastUsage = chunk.usage
203214
}

src/api/providers/openrouter.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ import { getModelEndpoints } from "./fetchers/modelEndpointCache"
2424

2525
import { DEFAULT_HEADERS } from "./constants"
2626
import { BaseProvider } from "./base-provider"
27-
import type { SingleCompletionHandler } from "../index"
27+
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
28+
import { getToolRegistry } from "../../core/prompts/tools/schemas/tool-registry"
2829

2930
// Add custom interface for OpenRouter params.
3031
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
@@ -72,10 +73,13 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
7273
override async *createMessage(
7374
systemPrompt: string,
7475
messages: Anthropic.Messages.MessageParam[],
76+
metadata?: ApiHandlerCreateMessageMetadata,
7577
): AsyncGenerator<ApiStreamChunk> {
7678
const model = await this.fetchModel()
7779

7880
let { id: modelId, maxTokens, temperature, topP, reasoning } = model
81+
const toolCallEnabled = metadata?.tools && metadata.tools.length > 0
82+
const toolRegistry = getToolRegistry()
7983

8084
// OpenRouter sends reasoning tokens by default for Gemini 2.5 Pro
8185
// Preview even if you don't request them. This is not the default for
@@ -133,6 +137,10 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
133137
...(transforms && { transforms }),
134138
...(reasoning && { reasoning }),
135139
}
140+
if (toolCallEnabled) {
141+
completionParams.tools = toolRegistry.generateFunctionCallSchemas(metadata.tools!, metadata.toolArgs!)
142+
completionParams.tool_choice = "auto"
143+
}
136144

137145
const stream = await this.client.chat.completions.create(completionParams)
138146

@@ -156,6 +164,10 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
156164
yield { type: "text", text: delta.content }
157165
}
158166

167+
if (delta?.tool_calls) {
168+
yield { type: "tool_call", toolCalls: delta.tool_calls, toolCallType: "openai" }
169+
}
170+
159171
if (chunk.usage) {
160172
lastUsage = chunk.usage
161173
}

src/api/transform/stream.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
import { ToolCallProviderType } from "../../shared/tools"
2+
13
export type ApiStream = AsyncGenerator<ApiStreamChunk>
24

3-
export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk | ApiStreamReasoningChunk | ApiStreamError
5+
export type ApiStreamChunk =
6+
| ApiStreamTextChunk
7+
| ApiStreamUsageChunk
8+
| ApiStreamReasoningChunk
9+
| ApiStreamError
10+
| ApiStreamToolCallChunk
411

512
export interface ApiStreamError {
613
type: "error"
@@ -27,3 +34,9 @@ export interface ApiStreamUsageChunk {
2734
reasoningTokens?: number
2835
totalCost?: number
2936
}
37+
38+
export interface ApiStreamToolCallChunk {
39+
type: "tool_call"
40+
toolCalls: any
41+
toolCallType: ToolCallProviderType
42+
}

src/core/assistant-message/AssistantMessageParser.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { type ToolName, toolNames } from "@roo-code/types"
22
import { TextContent, ToolUse, ToolParamName, toolParamNames } from "../../shared/tools"
33
import { AssistantMessageContent } from "./parseAssistantMessage"
4+
import { ToolCallParam } from "../task/tool-call-helper"
45

56
/**
67
* Parser for assistant messages. Maintains state between chunks
@@ -51,7 +52,7 @@ export class AssistantMessageParser {
5152
* Process a new chunk of text and update the parser state.
5253
* @param chunk The new chunk of text to process.
5354
*/
54-
public processChunk(chunk: string): AssistantMessageContent[] {
55+
public processChunk(chunk: string, toolCallParam?: ToolCallParam): AssistantMessageContent[] {
5556
if (this.accumulator.length + chunk.length > this.MAX_ACCUMULATOR_SIZE) {
5657
throw new Error("Assistant message exceeds maximum allowed size")
5758
}
@@ -174,6 +175,11 @@ export class AssistantMessageParser {
174175
name: extractedToolName as ToolName,
175176
params: {},
176177
partial: true,
178+
toolUseId: toolCallParam && toolCallParam.toolUserId ? toolCallParam.toolUserId : undefined,
179+
toolUseParam:
180+
toolCallParam && toolCallParam?.anthropicContent
181+
? toolCallParam?.anthropicContent
182+
: undefined,
177183
}
178184

179185
this.currentToolUseStartIndex = this.accumulator.length

src/core/assistant-message/parseAssistantMessage.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import { type ToolName, toolNames } from "@roo-code/types"
22

33
import { TextContent, ToolUse, ToolParamName, toolParamNames } from "../../shared/tools"
4+
import { ToolCallParam } from "../task/tool-call-helper"
45

56
export type AssistantMessageContent = TextContent | ToolUse
67

7-
export function parseAssistantMessage(assistantMessage: string): AssistantMessageContent[] {
8+
export function parseAssistantMessage(
9+
assistantMessage: string,
10+
toolCallParam?: ToolCallParam,
11+
): AssistantMessageContent[] {
812
let contentBlocks: AssistantMessageContent[] = []
913
let currentTextContent: TextContent | undefined = undefined
1014
let currentTextContentStartIndex = 0
@@ -103,6 +107,9 @@ export function parseAssistantMessage(assistantMessage: string): AssistantMessag
103107
name: toolUseOpeningTag.slice(1, -1) as ToolName,
104108
params: {},
105109
partial: true,
110+
toolUseId: toolCallParam && toolCallParam.toolUserId ? toolCallParam.toolUserId : undefined,
111+
toolUseParam:
112+
toolCallParam && toolCallParam?.anthropicContent ? toolCallParam?.anthropicContent : undefined,
106113
}
107114

108115
currentToolUseStartIndex = accumulator.length

src/core/assistant-message/presentAssistantMessage.ts

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import { Task } from "../task/Task"
3333
import { codebaseSearchTool } from "../tools/codebaseSearchTool"
3434
import { experiments, EXPERIMENT_IDS } from "../../shared/experiments"
3535
import { applyDiffToolLegacy } from "../tools/applyDiffTool"
36+
import Anthropic from "@anthropic-ai/sdk"
3637

3738
/**
3839
* Processes and presents assistant message content to the user interface.
@@ -61,6 +62,7 @@ export async function presentAssistantMessage(cline: Task) {
6162
return
6263
}
6364

65+
const toolCallEnabled = cline.apiConfiguration?.toolCallEnabled
6466
cline.presentAssistantMessageLocked = true
6567
cline.presentAssistantMessageHasPendingUpdates = false
6668

@@ -245,12 +247,28 @@ export async function presentAssistantMessage(cline: Task) {
245247
}
246248

247249
const pushToolResult = (content: ToolResponse) => {
248-
cline.userMessageContent.push({ type: "text", text: `${toolDescription()} Result:` })
249-
250+
const newUserMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[] = [
251+
{ type: "text", text: `${toolDescription()} Result:` },
252+
]
250253
if (typeof content === "string") {
251-
cline.userMessageContent.push({ type: "text", text: content || "(tool did not return anything)" })
254+
newUserMessages.push({ type: "text", text: content || "(tool did not return anything)" })
255+
} else {
256+
newUserMessages.push(...content)
257+
}
258+
259+
if (toolCallEnabled) {
260+
const lastToolUseMessage = cline.assistantMessageContent.find((msg) => msg.type === "tool_use")
261+
if (lastToolUseMessage && lastToolUseMessage.toolUseId) {
262+
const toolUseId = lastToolUseMessage.toolUseId
263+
const toolMessage: Anthropic.ToolResultBlockParam = {
264+
tool_use_id: toolUseId,
265+
type: "tool_result",
266+
content: newUserMessages,
267+
}
268+
cline.userMessageContent.push(toolMessage)
269+
}
252270
} else {
253-
cline.userMessageContent.push(...content)
271+
cline.userMessageContent.push(...newUserMessages)
254272
}
255273

256274
// Once a tool result has been collected, ignore all other tool
@@ -429,7 +447,7 @@ export async function presentAssistantMessage(cline: Task) {
429447
)
430448
}
431449

432-
if (isMultiFileApplyDiffEnabled) {
450+
if (isMultiFileApplyDiffEnabled || toolCallEnabled) {
433451
await checkpointSaveAndMark(cline)
434452
await applyDiffTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
435453
} else {

0 commit comments

Comments
 (0)