 
 import { NativeOllamaHandler } from "../native-ollama"
 import { ApiHandlerOptions } from "../../../shared/api"
+import { getOllamaModels } from "../fetchers/ollama"
 
 // Mock the ollama package
 const mockChat = vitest.fn()
@@ -16,22 +17,27 @@ vitest.mock("ollama", () => {
 
 // Mock the getOllamaModels function
 vitest.mock("../fetchers/ollama", () => ({
-	getOllamaModels: vitest.fn().mockResolvedValue({
-		llama2: {
-			contextWindow: 4096,
-			maxTokens: 4096,
-			supportsImages: false,
-			supportsPromptCache: false,
-		},
-	}),
+	getOllamaModels: vitest.fn(),
 }))
 
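+// vitest.mocked() wraps the auto-mocked fetcher with mock typings, so the
+// per-test overrides below are type-checked against getOllamaModels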
+const mockGetOllamaModels = vitest.mocked(getOllamaModels)
+
 describe("NativeOllamaHandler", () => {
 	let handler: NativeOllamaHandler
 
 	beforeEach(() => {
 		vitest.clearAllMocks()
 
+		// Default mock for getOllamaModels
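+		// (individual tests override this when they need different model info)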
+		mockGetOllamaModels.mockResolvedValue({
+			llama2: {
+				contextWindow: 4096,
+				maxTokens: 4096,
+				supportsImages: false,
+				supportsPromptCache: false,
+			},
+		})
+
 		const options: ApiHandlerOptions = {
 			apiModelId: "llama2",
 			ollamaModelId: "llama2",
@@ -257,4 +263,260 @@ describe("NativeOllamaHandler", () => {
 			expect(model.info).toBeDefined()
 		})
 	})
+
+	describe("tool calling", () => {
+		it("should include tools when model supports native tools", async () => {
+			// Mock model with native tool support
+			mockGetOllamaModels.mockResolvedValue({
+				"llama3.2": {
+					contextWindow: 128000,
+					maxTokens: 4096,
+					supportsImages: true,
+					supportsPromptCache: false,
+					supportsNativeTools: true,
+				},
+			})
+
+			const options: ApiHandlerOptions = {
+				apiModelId: "llama3.2",
+				ollamaModelId: "llama3.2",
+				ollamaBaseUrl: "http://localhost:11434",
+			}
+
+			handler = new NativeOllamaHandler(options)
+
+			// Mock the chat response
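+			// (ollama's chat() streams responses as an async iterable, so an
+			// async generator is a convenient stand-in)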
+			mockChat.mockImplementation(async function* () {
+				yield { message: { content: "I will use the tool" } }
+			})
+
+			const tools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "get_weather",
+						description: "Get the weather for a location",
+						parameters: {
+							type: "object",
+							properties: {
+								location: { type: "string", description: "The city name" },
+							},
+							required: ["location"],
+						},
+					},
+				},
+			]
+
+			const stream = handler.createMessage(
+				"System",
+				[{ role: "user" as const, content: "What's the weather?" }],
+				{ taskId: "test", tools },
+			)
+
+			// Consume the stream
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			// Verify tools were passed to the API
+			expect(mockChat).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: [
+						{
+							type: "function",
+							function: {
+								name: "get_weather",
+								description: "Get the weather for a location",
+								parameters: {
+									type: "object",
+									properties: {
+										location: { type: "string", description: "The city name" },
+									},
+									required: ["location"],
+								},
+							},
+						},
+					],
+				}),
+			)
+		})
+
+		it("should not include tools when model does not support native tools", async () => {
+			// Mock model without native tool support
+			mockGetOllamaModels.mockResolvedValue({
+				llama2: {
+					contextWindow: 4096,
+					maxTokens: 4096,
+					supportsImages: false,
+					supportsPromptCache: false,
+					supportsNativeTools: false,
+				},
+			})
+
+			// Mock the chat response
+			mockChat.mockImplementation(async function* () {
+				yield { message: { content: "Response without tools" } }
+			})
+
+			const tools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "get_weather",
+						description: "Get the weather",
+						parameters: { type: "object", properties: {} },
+					},
+				},
+			]
+
+			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
+				taskId: "test",
+				tools,
+			})
+
+			// Consume the stream
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			// Verify tools were NOT passed
+			expect(mockChat).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+				}),
+			)
+		})
+
+		it("should not include tools when toolProtocol is xml", async () => {
+			// Mock model with native tool support
+			mockGetOllamaModels.mockResolvedValue({
+				"llama3.2": {
+					contextWindow: 128000,
+					maxTokens: 4096,
+					supportsImages: true,
+					supportsPromptCache: false,
+					supportsNativeTools: true,
+				},
+			})
+
+			const options: ApiHandlerOptions = {
+				apiModelId: "llama3.2",
+				ollamaModelId: "llama3.2",
+				ollamaBaseUrl: "http://localhost:11434",
+			}
+
+			handler = new NativeOllamaHandler(options)
+
+			// Mock the chat response
+			mockChat.mockImplementation(async function* () {
+				yield { message: { content: "Response" } }
+			})
+
+			const tools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "get_weather",
+						description: "Get the weather",
+						parameters: { type: "object", properties: {} },
+					},
+				},
+			]
+
+			const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
+				taskId: "test",
+				tools,
+				toolProtocol: "xml",
+			})
+
+			// Consume the stream
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			// Verify tools were NOT passed (XML protocol forces XML format)
+			expect(mockChat).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+				}),
+			)
+		})
+
+		it("should yield tool_call_partial when model returns tool calls", async () => {
+			// Mock model with native tool support
+			mockGetOllamaModels.mockResolvedValue({
+				"llama3.2": {
+					contextWindow: 128000,
+					maxTokens: 4096,
+					supportsImages: true,
+					supportsPromptCache: false,
+					supportsNativeTools: true,
+				},
+			})
+
+			const options: ApiHandlerOptions = {
+				apiModelId: "llama3.2",
+				ollamaModelId: "llama3.2",
+				ollamaBaseUrl: "http://localhost:11434",
+			}
+
+			handler = new NativeOllamaHandler(options)
+
+			// Mock the chat response with tool calls
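+			// (arguments arrive as a parsed object; the handler is expected to
+			// re-serialize them to a JSON string, per the assertion below)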
+			mockChat.mockImplementation(async function* () {
+				yield {
+					message: {
+						content: "",
+						tool_calls: [
+							{
+								function: {
+									name: "get_weather",
+									arguments: { location: "San Francisco" },
+								},
+							},
+						],
+					},
+				}
+			})
+
+			const tools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "get_weather",
+						description: "Get the weather for a location",
+						parameters: {
+							type: "object",
+							properties: {
+								location: { type: "string" },
+							},
+							required: ["location"],
+						},
+					},
+				},
+			]
+
+			const stream = handler.createMessage(
+				"System",
+				[{ role: "user" as const, content: "What's the weather in SF?" }],
+				{ taskId: "test", tools },
+			)
+
+			const results = []
+			for await (const chunk of stream) {
+				results.push(chunk)
+			}
+
+			// Should yield a tool_call_partial chunk
+			const toolCallChunk = results.find((r) => r.type === "tool_call_partial")
+			expect(toolCallChunk).toBeDefined()
+			expect(toolCallChunk).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "ollama-tool-0",
+				name: "get_weather",
+				arguments: JSON.stringify({ location: "San Francisco" }),
+			})
+		})
+	})
 })