Skip to content

Commit 34c524f

Browse files
daniel-lxs and mrubens authored
feat(chutes): detect native tool support from API supported_features (#9715)
Co-authored-by: Matt Rubens <[email protected]>
1 parent bf3e4d8 commit 34c524f

File tree

5 files changed

+331
-3
lines changed

5 files changed

+331
-3
lines changed

src/api/providers/__tests__/chutes.spec.ts

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,83 @@ describe("ChutesHandler", () => {
233233
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
234234
})
235235

236+
it("createMessage should yield tool_call_partial from stream", async () => {
237+
mockCreate.mockImplementationOnce(() => {
238+
return {
239+
[Symbol.asyncIterator]: () => ({
240+
next: vi
241+
.fn()
242+
.mockResolvedValueOnce({
243+
done: false,
244+
value: {
245+
choices: [
246+
{
247+
delta: {
248+
tool_calls: [
249+
{
250+
index: 0,
251+
id: "call_123",
252+
function: { name: "test_tool", arguments: '{"arg":"value"}' },
253+
},
254+
],
255+
},
256+
},
257+
],
258+
},
259+
})
260+
.mockResolvedValueOnce({ done: true }),
261+
}),
262+
}
263+
})
264+
265+
const stream = handler.createMessage("system prompt", [])
266+
const firstChunk = await stream.next()
267+
268+
expect(firstChunk.done).toBe(false)
269+
expect(firstChunk.value).toEqual({
270+
type: "tool_call_partial",
271+
index: 0,
272+
id: "call_123",
273+
name: "test_tool",
274+
arguments: '{"arg":"value"}',
275+
})
276+
})
277+
278+
it("createMessage should pass tools and tool_choice to API", async () => {
279+
const tools = [
280+
{
281+
type: "function" as const,
282+
function: {
283+
name: "test_tool",
284+
description: "A test tool",
285+
parameters: { type: "object", properties: {} },
286+
},
287+
},
288+
]
289+
const tool_choice = "auto" as const
290+
291+
mockCreate.mockImplementationOnce(() => {
292+
return {
293+
[Symbol.asyncIterator]: () => ({
294+
next: vi.fn().mockResolvedValueOnce({ done: true }),
295+
}),
296+
}
297+
})
298+
299+
const stream = handler.createMessage("system prompt", [], { tools, tool_choice, taskId: "test-task-id" })
300+
// Consume stream
301+
for await (const _ of stream) {
302+
// noop
303+
}
304+
305+
expect(mockCreate).toHaveBeenCalledWith(
306+
expect.objectContaining({
307+
tools,
308+
tool_choice,
309+
}),
310+
)
311+
})
312+
236313
it("should apply DeepSeek default temperature for R1 models", () => {
237314
const testModelId = "deepseek-ai/DeepSeek-R1"
238315
const handlerWithModel = new ChutesHandler({

src/api/providers/chutes.ts

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
2828
private getCompletionParams(
2929
systemPrompt: string,
3030
messages: Anthropic.Messages.MessageParam[],
31+
metadata?: ApiHandlerCreateMessageMetadata,
3132
): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
3233
const { id: model, info } = this.getModel()
3334

@@ -46,6 +47,8 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
4647
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
4748
stream: true,
4849
stream_options: { include_usage: true },
50+
...(metadata?.tools && { tools: metadata.tools }),
51+
...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
4952
}
5053

5154
// Only add temperature if model supports it
@@ -65,7 +68,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
6568

6669
if (model.id.includes("DeepSeek-R1")) {
6770
const stream = await this.client.chat.completions.create({
68-
...this.getCompletionParams(systemPrompt, messages),
71+
...this.getCompletionParams(systemPrompt, messages, metadata),
6972
messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
7073
})
7174

@@ -87,6 +90,19 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
8790
}
8891
}
8992

93+
// Emit raw tool call chunks - NativeToolCallParser handles state management
94+
if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) {
95+
for (const toolCall of delta.tool_calls) {
96+
yield {
97+
type: "tool_call_partial",
98+
index: toolCall.index,
99+
id: toolCall.id,
100+
name: toolCall.function?.name,
101+
arguments: toolCall.function?.arguments,
102+
}
103+
}
104+
}
105+
90106
if (chunk.usage) {
91107
yield {
92108
type: "usage",
@@ -102,7 +118,9 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
102118
}
103119
} else {
104120
// For non-DeepSeek-R1 models, use standard OpenAI streaming
105-
const stream = await this.client.chat.completions.create(this.getCompletionParams(systemPrompt, messages))
121+
const stream = await this.client.chat.completions.create(
122+
this.getCompletionParams(systemPrompt, messages, metadata),
123+
)
106124

107125
for await (const chunk of stream) {
108126
const delta = chunk.choices[0]?.delta
@@ -115,6 +133,19 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
115133
yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
116134
}
117135

136+
// Emit raw tool call chunks - NativeToolCallParser handles state management
137+
if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) {
138+
for (const toolCall of delta.tool_calls) {
139+
yield {
140+
type: "tool_call_partial",
141+
index: toolCall.index,
142+
id: toolCall.id,
143+
name: toolCall.function?.name,
144+
arguments: toolCall.function?.arguments,
145+
}
146+
}
147+
}
148+
118149
if (chunk.usage) {
119150
yield {
120151
type: "usage",
@@ -166,6 +197,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
166197
override getModel() {
167198
const model = super.getModel()
168199
const isDeepSeekR1 = model.id.includes("DeepSeek-R1")
200+
169201
return {
170202
...model,
171203
info: {
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
// Mocks must come first, before imports
2+
vi.mock("axios")
3+
4+
import type { Mock } from "vitest"
5+
import type { ModelInfo } from "@roo-code/types"
6+
import axios from "axios"
7+
import { getChutesModels } from "../chutes"
8+
import { chutesModels } from "@roo-code/types"
9+
10+
const mockedAxios = axios as typeof axios & {
11+
get: Mock
12+
}
13+
14+
describe("getChutesModels", () => {
15+
beforeEach(() => {
16+
vi.clearAllMocks()
17+
})
18+
19+
it("should fetch and parse models successfully", async () => {
20+
const mockResponse = {
21+
data: {
22+
data: [
23+
{
24+
id: "test/new-model",
25+
object: "model",
26+
owned_by: "test",
27+
created: 1234567890,
28+
context_length: 128000,
29+
max_model_len: 8192,
30+
input_modalities: ["text"],
31+
},
32+
],
33+
},
34+
}
35+
36+
mockedAxios.get.mockResolvedValue(mockResponse)
37+
38+
const models = await getChutesModels("test-api-key")
39+
40+
expect(mockedAxios.get).toHaveBeenCalledWith(
41+
"https://llm.chutes.ai/v1/models",
42+
expect.objectContaining({
43+
headers: expect.objectContaining({
44+
Authorization: "Bearer test-api-key",
45+
}),
46+
}),
47+
)
48+
49+
expect(models["test/new-model"]).toEqual({
50+
maxTokens: 8192,
51+
contextWindow: 128000,
52+
supportsImages: false,
53+
supportsPromptCache: false,
54+
supportsNativeTools: false,
55+
inputPrice: 0,
56+
outputPrice: 0,
57+
description: "Chutes AI model: test/new-model",
58+
})
59+
})
60+
61+
it("should override hardcoded models with dynamic API data", async () => {
62+
// Find any hardcoded model
63+
const [modelId] = Object.entries(chutesModels)[0]
64+
65+
const mockResponse = {
66+
data: {
67+
data: [
68+
{
69+
id: modelId,
70+
object: "model",
71+
owned_by: "test",
72+
created: 1234567890,
73+
context_length: 200000, // Different from hardcoded
74+
max_model_len: 10000, // Different from hardcoded
75+
input_modalities: ["text", "image"],
76+
},
77+
],
78+
},
79+
}
80+
81+
mockedAxios.get.mockResolvedValue(mockResponse)
82+
83+
const models = await getChutesModels("test-api-key")
84+
85+
// Dynamic values should override hardcoded
86+
expect(models[modelId]).toBeDefined()
87+
expect(models[modelId].contextWindow).toBe(200000)
88+
expect(models[modelId].maxTokens).toBe(10000)
89+
expect(models[modelId].supportsImages).toBe(true)
90+
})
91+
92+
it("should return hardcoded models when API returns empty", async () => {
93+
const mockResponse = {
94+
data: {
95+
data: [],
96+
},
97+
}
98+
99+
mockedAxios.get.mockResolvedValue(mockResponse)
100+
101+
const models = await getChutesModels("test-api-key")
102+
103+
// Should still have hardcoded models
104+
expect(Object.keys(models).length).toBeGreaterThan(0)
105+
expect(models).toEqual(expect.objectContaining(chutesModels))
106+
})
107+
108+
it("should return hardcoded models on API error", async () => {
109+
mockedAxios.get.mockRejectedValue(new Error("Network error"))
110+
111+
const models = await getChutesModels("test-api-key")
112+
113+
// Should still have hardcoded models
114+
expect(Object.keys(models).length).toBeGreaterThan(0)
115+
expect(models).toEqual(chutesModels)
116+
})
117+
118+
it("should work without API key", async () => {
119+
const mockResponse = {
120+
data: {
121+
data: [],
122+
},
123+
}
124+
125+
mockedAxios.get.mockResolvedValue(mockResponse)
126+
127+
const models = await getChutesModels()
128+
129+
expect(mockedAxios.get).toHaveBeenCalledWith(
130+
"https://llm.chutes.ai/v1/models",
131+
expect.objectContaining({
132+
headers: expect.not.objectContaining({
133+
Authorization: expect.anything(),
134+
}),
135+
}),
136+
)
137+
138+
expect(Object.keys(models).length).toBeGreaterThan(0)
139+
})
140+
141+
it("should detect image support from input_modalities", async () => {
142+
const mockResponse = {
143+
data: {
144+
data: [
145+
{
146+
id: "test/image-model",
147+
object: "model",
148+
owned_by: "test",
149+
created: 1234567890,
150+
context_length: 128000,
151+
max_model_len: 8192,
152+
input_modalities: ["text", "image"],
153+
},
154+
],
155+
},
156+
}
157+
158+
mockedAxios.get.mockResolvedValue(mockResponse)
159+
160+
const models = await getChutesModels("test-api-key")
161+
162+
expect(models["test/image-model"].supportsImages).toBe(true)
163+
})
164+
165+
it("should detect native tool support from supported_features", async () => {
166+
const mockResponse = {
167+
data: {
168+
data: [
169+
{
170+
id: "test/tools-model",
171+
object: "model",
172+
owned_by: "test",
173+
created: 1234567890,
174+
context_length: 128000,
175+
max_model_len: 8192,
176+
input_modalities: ["text"],
177+
supported_features: ["json_mode", "tools", "reasoning"],
178+
},
179+
],
180+
},
181+
}
182+
183+
mockedAxios.get.mockResolvedValue(mockResponse)
184+
185+
const models = await getChutesModels("test-api-key")
186+
187+
expect(models["test/tools-model"].supportsNativeTools).toBe(true)
188+
})
189+
190+
it("should not enable native tool support when tools is not in supported_features", async () => {
191+
const mockResponse = {
192+
data: {
193+
data: [
194+
{
195+
id: "test/no-tools-model",
196+
object: "model",
197+
owned_by: "test",
198+
created: 1234567890,
199+
context_length: 128000,
200+
max_model_len: 8192,
201+
input_modalities: ["text"],
202+
supported_features: ["json_mode", "reasoning"],
203+
},
204+
],
205+
},
206+
}
207+
208+
mockedAxios.get.mockResolvedValue(mockResponse)
209+
210+
const models = await getChutesModels("test-api-key")
211+
212+
expect(models["test/no-tools-model"].supportsNativeTools).toBe(false)
213+
expect(models["test/no-tools-model"].defaultToolProtocol).toBeUndefined()
214+
})
215+
})

0 commit comments

Comments
 (0)