diff --git a/src/api/providers/__tests__/vscode-lm.test.ts b/src/api/providers/__tests__/vscode-lm.test.ts index 34e0d60b1d6..ae4d71f001b 100644 --- a/src/api/providers/__tests__/vscode-lm.test.ts +++ b/src/api/providers/__tests__/vscode-lm.test.ts @@ -235,13 +235,56 @@ describe("VsCodeLmHandler", () => { // consume stream } }).rejects.toThrow("API Error") + }) + + it("should execute tasks from tool calls", async () => { + const systemPrompt = "You are a helpful assistant" + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user" as const, + content: "Execute task", + }, + ] + + const toolCallData = { + name: "taskExecutor", + arguments: { task: "exampleTask" }, + callId: "call-2", + } + + mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ + stream: (async function* () { + yield new vscode.LanguageModelToolCallPart( + toolCallData.callId, + toolCallData.name, + toolCallData.arguments, + ) + return + })(), + text: (async function* () { + yield JSON.stringify({ type: "tool_call", ...toolCallData }) + return + })(), + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks).toHaveLength(2) // Tool call chunk + usage chunk + expect(chunks[0]).toEqual({ + type: "text", + text: JSON.stringify({ type: "tool_call", ...toolCallData }), + }) }) }) describe("getModel", () => { it("should return model info when client exists", async () => { const mockModel = { ...mockLanguageModelChat } - ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]) + ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]) // Initialize client await handler["getClient"]() @@ -291,5 +334,39 @@ describe("VsCodeLmHandler", () => { "VSCode LM completion error: Completion failed", ) }) + + it("should execute tasks during completion", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel]) + + const responseText = "Completed text" + const toolCallData = { + name: "taskExecutor", + arguments: { task: "exampleTask" }, + callId: "call-3", + } + + mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ + stream: (async function* () { + yield new vscode.LanguageModelTextPart(responseText) + yield new vscode.LanguageModelToolCallPart( + toolCallData.callId, + toolCallData.name, + toolCallData.arguments, + ) + return + })(), + text: (async function* () { + yield responseText + yield JSON.stringify({ type: "tool_call", ...toolCallData }) + return + })(), + }) + + const result = await handler.completePrompt("Test prompt") + expect(result).toContain(responseText) + expect(result).toContain(JSON.stringify({ type: "tool_call", ...toolCallData })) + expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled() + }) }) }) diff --git a/src/api/providers/vscode-lm.ts b/src/api/providers/vscode-lm.ts index bf1215e2388..6c66429889f 100644 --- a/src/api/providers/vscode-lm.ts +++ b/src/api/providers/vscode-lm.ts @@ -440,6 +440,10 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan inputSize: JSON.stringify(chunk.input).length, }) + // Execute the tool call + const toolResult = await this.executeToolCall(toolCall) + accumulatedText += toolResult + yield { type: "text", text: toolCallText, @@ -563,6 +567,41 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan for await (const chunk of response.stream) { if (chunk instanceof vscode.LanguageModelTextPart) { result += chunk.value + } else if (chunk instanceof vscode.LanguageModelToolCallPart) { + try { + // Validate tool call parameters + if (!chunk.name || typeof chunk.name !== "string") { + console.warn("Roo Code : Invalid tool name received:", chunk.name) + continue + } + + if (!chunk.callId || typeof chunk.callId !== "string") { + console.warn("Roo Code : Invalid tool callId received:", chunk.callId) + continue + } + + // Ensure input is a valid object + if (!chunk.input || typeof chunk.input !== "object") { + console.warn("Roo Code : Invalid tool input received:", chunk.input) + continue + } + + // Convert tool calls to text format with proper error handling + const toolCall = { + type: "tool_call", + name: chunk.name, + arguments: chunk.input, + callId: chunk.callId, + } + + // Execute the tool call + const toolResult = await this.executeToolCall(toolCall) + result += JSON.stringify(toolCall) + } catch (error) { + console.error("Roo Code : Failed to process tool call:", error) + // Continue processing other chunks even if one fails + continue + } } } return result @@ -573,6 +612,13 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan throw error } } + + private async executeToolCall(toolCall: { type: string; name: string; arguments: any; callId: string }): Promise { + // Implement the logic to execute the tool call based on the tool name and arguments + // This is a placeholder implementation and should be replaced with actual tool execution logic + console.log(`Executing tool call: ${toolCall.name} with arguments: ${JSON.stringify(toolCall.arguments)}`) + return `` + } } export async function getVsCodeLmModels() {