Skip to content

Commit 3dc988a

Browse files
committed
feat: Implement streaming JSON to XML converter for tool calls
- Added a new file `tool-call-helper.ts` that contains the `StreamingToolCallProcessor` class for converting tool call JSON chunks into XML format.
- Enhanced the `applyDiffTool` function to support new search and replace fields in parsed diffs.
- Updated the `generateSystemPrompt` function to include tool call support based on provider configuration.
- Introduced a new `supportToolCall` function to determine if a provider supports tool calls.
- Added a new settings control component for enabling/disabling tool calls in the UI.
- Updated localization files to include translations for the new tool call feature.
1 parent e5d93f2 commit 3dc988a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+2893
-64
lines changed

packages/types/src/provider-settings.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ const baseProviderSettingsSchema = z.object({
7878
includeMaxTokens: z.boolean().optional(),
7979
diffEnabled: z.boolean().optional(),
8080
todoListEnabled: z.boolean().optional(),
81+
toolCallEnabled: z.boolean().optional(),
8182
fuzzyMatchThreshold: z.number().optional(),
8283
modelTemperature: z.number().nullish(),
8384
rateLimitSeconds: z.number().optional(),

src/api/index.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { Anthropic } from "@anthropic-ai/sdk"
22

3-
import type { ProviderSettings, ModelInfo } from "@roo-code/types"
3+
import type { ProviderSettings, ModelInfo, ToolName } from "@roo-code/types"
44

55
import { ApiStream } from "./transform/stream"
66

@@ -37,6 +37,7 @@ import {
3737
ZAiHandler,
3838
FireworksHandler,
3939
} from "./providers"
40+
import { ToolArgs } from "../core/prompts/tools/types"
4041

4142
export interface SingleCompletionHandler {
4243
completePrompt(prompt: string): Promise<string>
@@ -52,6 +53,14 @@ export interface ApiHandlerCreateMessageMetadata {
5253
* Used to enforce "skip once" after a condense operation.
5354
*/
5455
suppressPreviousResponseId?: boolean
56+
/**
57+
* tool call
58+
*/
59+
tools?: ToolName[]
60+
/**
61+
* tool call args
62+
*/
63+
toolArgs?: ToolArgs
5564
}
5665

5766
export interface ApiHandler {

src/api/providers/base-provider.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,34 @@ export abstract class BaseProvider implements ApiHandler {
3232

3333
return countTokens(content, { useWorker: true })
3434
}
35+
36+
/**
37+
* Convert tool schemas to text format for token counting
38+
*/
39+
protected convertToolSchemasToText(toolSchemas: Anthropic.ToolUnion[]): string {
40+
if (toolSchemas.length === 0) {
41+
return ""
42+
}
43+
44+
const toolsDescription = toolSchemas
45+
.map((tool) => {
46+
// Handle different tool types by accessing properties safely
47+
const toolName = tool.name
48+
let toolText = `Tool: ${toolName}\n`
49+
50+
// Try to access description and input_schema properties
51+
if ("description" in tool) {
52+
toolText += `Description: ${tool.description}\n`
53+
}
54+
55+
if ("input_schema" in tool && tool.input_schema && typeof tool.input_schema === "object") {
56+
toolText += `Parameters:\n${JSON.stringify(tool.input_schema, null, 2)}\n`
57+
}
58+
59+
return toolText
60+
})
61+
.join("\n---\n")
62+
63+
return `Available Tools:\n${toolsDescription}`
64+
}
3565
}

src/api/providers/lm-studio.ts

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import { BaseProvider } from "./base-provider"
1515
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
1616
import { getModels, getModelsFromCache } from "./fetchers/modelCache"
1717
import { getApiRequestTimeout } from "./utils/timeout-config"
18+
import { getToolRegistry } from "../../core/prompts/tools/schemas/tool-registry"
1819

1920
export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
2021
protected options: ApiHandlerOptions
@@ -40,6 +41,8 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
4041
{ role: "system", content: systemPrompt },
4142
...convertToOpenAiMessages(messages),
4243
]
44+
const toolCallEnabled = metadata?.tools && metadata.tools.length > 0
45+
const toolRegistry = getToolRegistry()
4346

4447
// -------------------------
4548
// Track token usage
@@ -68,7 +71,17 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
6871

6972
let inputTokens = 0
7073
try {
71-
inputTokens = await this.countTokens([{ type: "text", text: systemPrompt }, ...toContentBlocks(messages)])
74+
const inputMessages: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: systemPrompt }]
75+
if (toolCallEnabled) {
76+
const toolSchemas: Anthropic.ToolUnion[] = toolRegistry.generateAnthropicToolSchemas(
77+
metadata.tools!,
78+
metadata.toolArgs,
79+
)
80+
const toolsText = this.convertToolSchemasToText(toolSchemas)
81+
inputMessages.push({ type: "text", text: toolsText })
82+
}
83+
inputMessages.push(...toContentBlocks(messages))
84+
inputTokens = await this.countTokens(inputMessages)
7285
} catch (err) {
7386
console.error("[LmStudio] Failed to count input tokens:", err)
7487
inputTokens = 0
@@ -83,6 +96,9 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
8396
temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
8497
stream: true,
8598
}
99+
if (toolCallEnabled) {
100+
params.tools = toolRegistry.generateFunctionCallSchemas(metadata.tools!, metadata.toolArgs)
101+
}
86102

87103
if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
88104
params.draft_model = this.options.lmStudioDraftModelId
@@ -108,6 +124,9 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
108124
yield processedChunk
109125
}
110126
}
127+
if (delta?.tool_calls) {
128+
yield { type: "tool_call", toolCalls: delta.tool_calls, toolCallType: "openai" }
129+
}
111130
}
112131

113132
for (const processedChunk of matcher.final()) {

src/api/providers/openai.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import { DEFAULT_HEADERS } from "./constants"
2424
import { BaseProvider } from "./base-provider"
2525
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
2626
import { getApiRequestTimeout } from "./utils/timeout-config"
27+
import { getToolRegistry } from "../../core/prompts/tools/schemas/tool-registry"
2728

2829
// TODO: Rename this to OpenAICompatibleHandler. Also, I think the
2930
// `OpenAINativeHandler` can subclass from this, since it's obviously
@@ -92,6 +93,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
9293
const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
9394
const ark = modelUrl.includes(".volces.com")
9495

96+
const toolCallEnabled = metadata?.tools && metadata.tools.length > 0
97+
const toolRegistry = getToolRegistry()
98+
9599
if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
96100
yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
97101
return
@@ -163,6 +167,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
163167
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
164168
...(reasoning && reasoning),
165169
}
170+
if (toolCallEnabled) {
171+
requestOptions.tools = toolRegistry.generateFunctionCallSchemas(metadata.tools!, metadata.toolArgs)
172+
}
166173

167174
// Add max_tokens if needed
168175
this.addMaxTokensIfNeeded(requestOptions, modelInfo)
@@ -198,6 +205,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
198205
text: (delta.reasoning_content as string | undefined) || "",
199206
}
200207
}
208+
if (delta?.tool_calls) {
209+
yield { type: "tool_call", toolCalls: delta.tool_calls, toolCallType: "openai" }
210+
}
201211
if (chunk.usage) {
202212
lastUsage = chunk.usage
203213
}

src/api/providers/openrouter.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ import { getModelEndpoints } from "./fetchers/modelEndpointCache"
2424

2525
import { DEFAULT_HEADERS } from "./constants"
2626
import { BaseProvider } from "./base-provider"
27-
import type { SingleCompletionHandler } from "../index"
27+
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
28+
import { getToolRegistry } from "../../core/prompts/tools/schemas/tool-registry"
2829

2930
// Add custom interface for OpenRouter params.
3031
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
@@ -72,10 +73,13 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
7273
override async *createMessage(
7374
systemPrompt: string,
7475
messages: Anthropic.Messages.MessageParam[],
76+
metadata?: ApiHandlerCreateMessageMetadata,
7577
): AsyncGenerator<ApiStreamChunk> {
7678
const model = await this.fetchModel()
7779

7880
let { id: modelId, maxTokens, temperature, topP, reasoning } = model
81+
const toolCallEnabled = metadata?.tools && metadata.tools.length > 0
82+
const toolRegistry = getToolRegistry()
7983

8084
// OpenRouter sends reasoning tokens by default for Gemini 2.5 Pro
8185
// Preview even if you don't request them. This is not the default for
@@ -133,6 +137,9 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
133137
...(transforms && { transforms }),
134138
...(reasoning && { reasoning }),
135139
}
140+
if (toolCallEnabled) {
141+
completionParams.tools = toolRegistry.generateFunctionCallSchemas(metadata.tools!, metadata.toolArgs!)
142+
}
136143

137144
const stream = await this.client.chat.completions.create(completionParams)
138145

@@ -156,6 +163,10 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
156163
yield { type: "text", text: delta.content }
157164
}
158165

166+
if (delta?.tool_calls) {
167+
yield { type: "tool_call", toolCalls: delta.tool_calls, toolCallType: "openai" }
168+
}
169+
159170
if (chunk.usage) {
160171
lastUsage = chunk.usage
161172
}

src/api/transform/stream.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
import { ToolCallProviderType } from "../../shared/api"
2+
13
export type ApiStream = AsyncGenerator<ApiStreamChunk>
24

3-
export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk | ApiStreamReasoningChunk | ApiStreamError
5+
export type ApiStreamChunk =
6+
| ApiStreamTextChunk
7+
| ApiStreamUsageChunk
8+
| ApiStreamReasoningChunk
9+
| ApiStreamError
10+
| ApiStreamToolCallChunk
411

512
export interface ApiStreamError {
613
type: "error"
@@ -27,3 +34,9 @@ export interface ApiStreamUsageChunk {
2734
reasoningTokens?: number
2835
totalCost?: number
2936
}
37+
38+
export interface ApiStreamToolCallChunk {
39+
type: "tool_call"
40+
toolCalls: any
41+
toolCallType: ToolCallProviderType
42+
}

src/core/assistant-message/presentAssistantMessage.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ export async function presentAssistantMessage(cline: Task) {
429429
)
430430
}
431431

432-
if (isMultiFileApplyDiffEnabled) {
432+
if (isMultiFileApplyDiffEnabled || cline.apiConfiguration.toolCallEnabled === true) {
433433
await checkpointSaveAndMark(cline)
434434
await applyDiffTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
435435
} else {

src/core/config/ProviderSettingsManager.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ export const providerProfilesSchema = z.object({
3232
openAiHeadersMigrated: z.boolean().optional(),
3333
consecutiveMistakeLimitMigrated: z.boolean().optional(),
3434
todoListEnabledMigrated: z.boolean().optional(),
35+
toolCallEnabledMigrated: z.boolean().optional(),
3536
})
3637
.optional(),
3738
})
@@ -56,6 +57,7 @@ export class ProviderSettingsManager {
5657
openAiHeadersMigrated: true, // Mark as migrated on fresh installs
5758
consecutiveMistakeLimitMigrated: true, // Mark as migrated on fresh installs
5859
todoListEnabledMigrated: true, // Mark as migrated on fresh installs
60+
toolCallEnabledMigrated: true, // Mark as migrated on fresh installs
5961
},
6062
}
6163

@@ -156,6 +158,11 @@ export class ProviderSettingsManager {
156158
providerProfiles.migrations.todoListEnabledMigrated = true
157159
isDirty = true
158160
}
161+
if (!providerProfiles.migrations.toolCallEnabledMigrated) {
162+
await this.migrateToolCallEnabled(providerProfiles)
163+
providerProfiles.migrations.toolCallEnabledMigrated = true
164+
isDirty = true
165+
}
159166

160167
if (isDirty) {
161168
await this.store(providerProfiles)
@@ -273,6 +280,17 @@ export class ProviderSettingsManager {
273280
console.error(`[MigrateTodoListEnabled] Failed to migrate todo list enabled setting:`, error)
274281
}
275282
}
283+
private async migrateToolCallEnabled(providerProfiles: ProviderProfiles) {
284+
try {
285+
for (const [_name, apiConfig] of Object.entries(providerProfiles.apiConfigs)) {
286+
if (apiConfig.toolCallEnabled === undefined) {
287+
apiConfig.toolCallEnabled = false
288+
}
289+
}
290+
} catch (error) {
291+
console.error(`[migrateToolCallEnabled] Failed to migrate tool call enabled setting:`, error)
292+
}
293+
}
276294

277295
/**
278296
* List all available configs with metadata.

src/core/config/__tests__/ProviderSettingsManager.spec.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ describe("ProviderSettingsManager", () => {
6868
openAiHeadersMigrated: true,
6969
consecutiveMistakeLimitMigrated: true,
7070
todoListEnabledMigrated: true,
71+
toolCallEnabledMigrated: true,
7172
},
7273
}),
7374
)

0 commit comments

Comments (0)