Skip to content

Commit be76594

Browse files
mrubensroomote
andauthored
Support tool calling in native ollama provider (#9696)
Co-authored-by: Roo Code <[email protected]>
1 parent faa6c40 commit be76594

File tree

8 files changed

+336
-435
lines changed

8 files changed

+336
-435
lines changed

packages/types/src/providers/ollama.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export const ollamaDefaultModelInfo: ModelInfo = {
88
contextWindow: 200_000,
99
supportsImages: true,
1010
supportsPromptCache: true,
11+
supportsNativeTools: true,
1112
inputPrice: 0,
1213
outputPrice: 0,
1314
cacheWritesPrice: 0,

src/api/providers/__tests__/native-ollama.spec.ts

Lines changed: 270 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import { NativeOllamaHandler } from "../native-ollama"
44
import { ApiHandlerOptions } from "../../../shared/api"
5+
import { getOllamaModels } from "../fetchers/ollama"
56

67
// Mock the ollama package
78
const mockChat = vitest.fn()
@@ -16,22 +17,27 @@ vitest.mock("ollama", () => {
1617

1718
// Mock the getOllamaModels function
1819
vitest.mock("../fetchers/ollama", () => ({
19-
getOllamaModels: vitest.fn().mockResolvedValue({
20-
llama2: {
21-
contextWindow: 4096,
22-
maxTokens: 4096,
23-
supportsImages: false,
24-
supportsPromptCache: false,
25-
},
26-
}),
20+
getOllamaModels: vitest.fn(),
2721
}))
2822

23+
const mockGetOllamaModels = vitest.mocked(getOllamaModels)
24+
2925
describe("NativeOllamaHandler", () => {
3026
let handler: NativeOllamaHandler
3127

3228
beforeEach(() => {
3329
vitest.clearAllMocks()
3430

31+
// Default mock for getOllamaModels
32+
mockGetOllamaModels.mockResolvedValue({
33+
llama2: {
34+
contextWindow: 4096,
35+
maxTokens: 4096,
36+
supportsImages: false,
37+
supportsPromptCache: false,
38+
},
39+
})
40+
3541
const options: ApiHandlerOptions = {
3642
apiModelId: "llama2",
3743
ollamaModelId: "llama2",
@@ -257,4 +263,260 @@ describe("NativeOllamaHandler", () => {
257263
expect(model.info).toBeDefined()
258264
})
259265
})
266+
267+
describe("tool calling", () => {
268+
it("should include tools when model supports native tools", async () => {
269+
// Mock model with native tool support
270+
mockGetOllamaModels.mockResolvedValue({
271+
"llama3.2": {
272+
contextWindow: 128000,
273+
maxTokens: 4096,
274+
supportsImages: true,
275+
supportsPromptCache: false,
276+
supportsNativeTools: true,
277+
},
278+
})
279+
280+
const options: ApiHandlerOptions = {
281+
apiModelId: "llama3.2",
282+
ollamaModelId: "llama3.2",
283+
ollamaBaseUrl: "http://localhost:11434",
284+
}
285+
286+
handler = new NativeOllamaHandler(options)
287+
288+
// Mock the chat response
289+
mockChat.mockImplementation(async function* () {
290+
yield { message: { content: "I will use the tool" } }
291+
})
292+
293+
const tools = [
294+
{
295+
type: "function" as const,
296+
function: {
297+
name: "get_weather",
298+
description: "Get the weather for a location",
299+
parameters: {
300+
type: "object",
301+
properties: {
302+
location: { type: "string", description: "The city name" },
303+
},
304+
required: ["location"],
305+
},
306+
},
307+
},
308+
]
309+
310+
const stream = handler.createMessage(
311+
"System",
312+
[{ role: "user" as const, content: "What's the weather?" }],
313+
{ taskId: "test", tools },
314+
)
315+
316+
// Consume the stream
317+
for await (const _ of stream) {
318+
// consume stream
319+
}
320+
321+
// Verify tools were passed to the API
322+
expect(mockChat).toHaveBeenCalledWith(
323+
expect.objectContaining({
324+
tools: [
325+
{
326+
type: "function",
327+
function: {
328+
name: "get_weather",
329+
description: "Get the weather for a location",
330+
parameters: {
331+
type: "object",
332+
properties: {
333+
location: { type: "string", description: "The city name" },
334+
},
335+
required: ["location"],
336+
},
337+
},
338+
},
339+
],
340+
}),
341+
)
342+
})
343+
344+
it("should not include tools when model does not support native tools", async () => {
345+
// Mock model without native tool support
346+
mockGetOllamaModels.mockResolvedValue({
347+
llama2: {
348+
contextWindow: 4096,
349+
maxTokens: 4096,
350+
supportsImages: false,
351+
supportsPromptCache: false,
352+
supportsNativeTools: false,
353+
},
354+
})
355+
356+
// Mock the chat response
357+
mockChat.mockImplementation(async function* () {
358+
yield { message: { content: "Response without tools" } }
359+
})
360+
361+
const tools = [
362+
{
363+
type: "function" as const,
364+
function: {
365+
name: "get_weather",
366+
description: "Get the weather",
367+
parameters: { type: "object", properties: {} },
368+
},
369+
},
370+
]
371+
372+
const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
373+
taskId: "test",
374+
tools,
375+
})
376+
377+
// Consume the stream
378+
for await (const _ of stream) {
379+
// consume stream
380+
}
381+
382+
// Verify tools were NOT passed
383+
expect(mockChat).toHaveBeenCalledWith(
384+
expect.not.objectContaining({
385+
tools: expect.anything(),
386+
}),
387+
)
388+
})
389+
390+
it("should not include tools when toolProtocol is xml", async () => {
391+
// Mock model with native tool support
392+
mockGetOllamaModels.mockResolvedValue({
393+
"llama3.2": {
394+
contextWindow: 128000,
395+
maxTokens: 4096,
396+
supportsImages: true,
397+
supportsPromptCache: false,
398+
supportsNativeTools: true,
399+
},
400+
})
401+
402+
const options: ApiHandlerOptions = {
403+
apiModelId: "llama3.2",
404+
ollamaModelId: "llama3.2",
405+
ollamaBaseUrl: "http://localhost:11434",
406+
}
407+
408+
handler = new NativeOllamaHandler(options)
409+
410+
// Mock the chat response
411+
mockChat.mockImplementation(async function* () {
412+
yield { message: { content: "Response" } }
413+
})
414+
415+
const tools = [
416+
{
417+
type: "function" as const,
418+
function: {
419+
name: "get_weather",
420+
description: "Get the weather",
421+
parameters: { type: "object", properties: {} },
422+
},
423+
},
424+
]
425+
426+
const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
427+
taskId: "test",
428+
tools,
429+
toolProtocol: "xml",
430+
})
431+
432+
// Consume the stream
433+
for await (const _ of stream) {
434+
// consume stream
435+
}
436+
437+
// Verify tools were NOT passed (XML protocol forces XML format)
438+
expect(mockChat).toHaveBeenCalledWith(
439+
expect.not.objectContaining({
440+
tools: expect.anything(),
441+
}),
442+
)
443+
})
444+
445+
it("should yield tool_call_partial when model returns tool calls", async () => {
446+
// Mock model with native tool support
447+
mockGetOllamaModels.mockResolvedValue({
448+
"llama3.2": {
449+
contextWindow: 128000,
450+
maxTokens: 4096,
451+
supportsImages: true,
452+
supportsPromptCache: false,
453+
supportsNativeTools: true,
454+
},
455+
})
456+
457+
const options: ApiHandlerOptions = {
458+
apiModelId: "llama3.2",
459+
ollamaModelId: "llama3.2",
460+
ollamaBaseUrl: "http://localhost:11434",
461+
}
462+
463+
handler = new NativeOllamaHandler(options)
464+
465+
// Mock the chat response with tool calls
466+
mockChat.mockImplementation(async function* () {
467+
yield {
468+
message: {
469+
content: "",
470+
tool_calls: [
471+
{
472+
function: {
473+
name: "get_weather",
474+
arguments: { location: "San Francisco" },
475+
},
476+
},
477+
],
478+
},
479+
}
480+
})
481+
482+
const tools = [
483+
{
484+
type: "function" as const,
485+
function: {
486+
name: "get_weather",
487+
description: "Get the weather for a location",
488+
parameters: {
489+
type: "object",
490+
properties: {
491+
location: { type: "string" },
492+
},
493+
required: ["location"],
494+
},
495+
},
496+
},
497+
]
498+
499+
const stream = handler.createMessage(
500+
"System",
501+
[{ role: "user" as const, content: "What's the weather in SF?" }],
502+
{ taskId: "test", tools },
503+
)
504+
505+
const results = []
506+
for await (const chunk of stream) {
507+
results.push(chunk)
508+
}
509+
510+
// Should yield a tool_call_partial chunk
511+
const toolCallChunk = results.find((r) => r.type === "tool_call_partial")
512+
expect(toolCallChunk).toBeDefined()
513+
expect(toolCallChunk).toEqual({
514+
type: "tool_call_partial",
515+
index: 0,
516+
id: "ollama-tool-0",
517+
name: "get_weather",
518+
arguments: JSON.stringify({ location: "San Francisco" }),
519+
})
520+
})
521+
})
260522
})

0 commit comments

Comments
 (0)