Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 94010a5

Browse files
authored
FEATURE: Tools for models from Ollama provider (#819)
Adds support for Ollama function calling
1 parent 6c4c96e commit 94010a5

File tree

9 files changed

+404
-20
lines changed

9 files changed

+404
-20
lines changed

app/models/llm_model.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def self.provider_params
3131
},
3232
ollama: {
3333
disable_system_prompt: :checkbox,
34+
enable_native_tool: :checkbox,
3435
},
3536
}
3637
end

config/locales/client.en.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ en:
312312
region: "AWS Bedrock Region"
313313
organization: "Optional OpenAI Organization ID"
314314
disable_system_prompt: "Disable system message in prompts"
315+
enable_native_tool: "Enable native tool support"
315316

316317
related_topics:
317318
title: "Related Topics"

lib/completions/dialects/ollama.rb

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,24 @@ def can_translate?(model_provider)
1010
end
1111
end
1212

13-
# TODO: Add tool support
13+
def native_tool_support?
14+
enable_native_tool?
15+
end
1416

1517
def max_prompt_tokens
1618
llm_model.max_prompt_tokens
1719
end
1820

1921
private
2022

23+
def tools_dialect
24+
if enable_native_tool?
25+
@tools_dialect ||= DiscourseAi::Completions::Dialects::OllamaTools.new(prompt.tools)
26+
else
27+
super
28+
end
29+
end
30+
2131
def tokenizer
2232
llm_model.tokenizer_class
2333
end
@@ -26,8 +36,28 @@ def model_msg(msg)
2636
{ role: "assistant", content: msg[:content] }
2737
end
2838

39+
def tool_call_msg(msg)
40+
tools_dialect.from_raw_tool_call(msg)
41+
end
42+
43+
def tool_msg(msg)
44+
tools_dialect.from_raw_tool(msg)
45+
end
46+
2947
def system_msg(msg)
30-
{ role: "system", content: msg[:content] }
48+
msg = { role: "system", content: msg[:content] }
49+
50+
if tools_dialect.instructions.present?
51+
msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
52+
end
53+
54+
msg
55+
end
56+
57+
def enable_native_tool?
58+
return @enable_native_tool if defined?(@enable_native_tool)
59+
60+
@enable_native_tool = llm_model.lookup_custom_param("enable_native_tool")
3161
end
3262

3363
def user_msg(msg)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# frozen_string_literal: true
2+
3+
module DiscourseAi
4+
module Completions
5+
module Dialects
6+
# TODO: Define the Tool class to be inherited by all tools.
7+
class OllamaTools
8+
def initialize(tools)
9+
@raw_tools = tools
10+
end
11+
12+
def instructions
13+
"" # Noop. Tools are listed separate.
14+
end
15+
16+
def translated_tools
17+
raw_tools.map do |t|
18+
tool = t.dup
19+
20+
tool[:parameters] = t[:parameters]
21+
.to_a
22+
.reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
23+
name = p[:name]
24+
memo[:required] << name if p[:required]
25+
26+
except = %i[name required item_type]
27+
except << :enum if p[:enum].blank?
28+
29+
memo[:properties][name] = p.except(*except)
30+
memo
31+
end
32+
33+
{ type: "function", function: tool }
34+
end
35+
end
36+
37+
def from_raw_tool_call(raw_message)
38+
call_details = JSON.parse(raw_message[:content], symbolize_names: true)
39+
call_details[:name] = raw_message[:name]
40+
41+
{
42+
role: "assistant",
43+
content: nil,
44+
tool_calls: [{ type: "function", function: call_details }],
45+
}
46+
end
47+
48+
def from_raw_tool(raw_message)
49+
{ role: "tool", content: raw_message[:content], name: raw_message[:name] }
50+
end
51+
52+
private
53+
54+
attr_reader :raw_tools
55+
end
56+
end
57+
end
58+
end

lib/completions/endpoints/ollama.rb

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,28 @@ def model_uri
3737
URI(llm_model.url)
3838
end
3939

40-
def prepare_payload(prompt, model_params, _dialect)
40+
def native_tool_support?
41+
@native_tool_support
42+
end
43+
44+
def has_tool?(_response_data)
45+
@has_function_call
46+
end
47+
48+
def prepare_payload(prompt, model_params, dialect)
49+
@native_tool_support = dialect.native_tool_support?
50+
51+
# https://github.com/ollama/ollama/blob/main/docs/api.md#parameters-1
52+
# Because Ollama enforces 'stream: false' for tool calls, instead of complicating the code,
53+
# we will just disable streaming for all ollama calls if native tool support is enabled
54+
4155
default_options
4256
.merge(model_params)
4357
.merge(messages: prompt)
44-
.tap { |payload| payload[:stream] = false if !@streaming_mode }
58+
.tap { |payload| payload[:stream] = false if @native_tool_support || !@streaming_mode }
59+
.tap do |payload|
60+
payload[:tools] = dialect.tools if @native_tool_support && dialect.tools.present?
61+
end
4562
end
4663

4764
def prepare_request(payload)
@@ -58,7 +75,66 @@ def extract_completion_from(response_raw)
5875
parsed = JSON.parse(response_raw, symbolize_names: true)
5976
return if !parsed
6077

61-
parsed.dig(:message, :content)
78+
response_h = parsed.dig(:message)
79+
80+
@has_function_call ||= response_h.dig(:tool_calls).present?
81+
@has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
82+
end
83+
84+
def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
85+
@args_buffer ||= +""
86+
87+
if @streaming_mode
88+
return function_buffer if !partial
89+
else
90+
partial = payload
91+
end
92+
93+
f_name = partial.dig(:function, :name)
94+
95+
@current_function ||= function_buffer.at("invoke")
96+
97+
if f_name
98+
current_name = function_buffer.at("tool_name").content
99+
100+
if current_name.blank?
101+
# first call
102+
else
103+
# we have a previous function, so we need to add a noop
104+
@args_buffer = +""
105+
@current_function =
106+
function_buffer.at("function_calls").add_child(
107+
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
108+
)
109+
end
110+
end
111+
112+
@current_function.at("tool_name").content = f_name if f_name
113+
@current_function.at("tool_id").content = partial[:id] if partial[:id]
114+
115+
args = partial.dig(:function, :arguments)
116+
117+
# allow for SPACE within arguments
118+
if args && args != ""
119+
@args_buffer << args.to_json
120+
121+
begin
122+
json_args = JSON.parse(@args_buffer, symbolize_names: true)
123+
124+
argument_fragments =
125+
json_args.reduce(+"") do |memo, (arg_name, value)|
126+
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
127+
end
128+
argument_fragments << "\n"
129+
130+
@current_function.at("parameters").children =
131+
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
132+
rescue JSON::ParserError
133+
return function_buffer
134+
end
135+
end
136+
137+
function_buffer
62138
end
63139
end
64140
end

spec/fabricators/llm_model_fabricator.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,5 @@
8787
api_key "ABC"
8888
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
8989
url "http://api.ollama.ai/api/chat"
90+
provider_params { { enable_native_tool: true } }
9091
end

spec/lib/completions/dialects/ollama_spec.rb

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,37 @@
77
let(:context) { DialectContext.new(described_class, model) }
88

99
describe "#translate" do
10-
it "translates a prompt written in our generic format to the Ollama format" do
11-
ollama_version = [
12-
{ role: "system", content: context.system_insts },
13-
{ role: "user", content: context.simple_user_input },
14-
]
10+
context "when native tool support is enabled" do
11+
it "translates a prompt written in our generic format to the Ollama format" do
12+
ollama_version = [
13+
{ role: "system", content: context.system_insts },
14+
{ role: "user", content: context.simple_user_input },
15+
]
1516

16-
translated = context.system_user_scenario
17+
translated = context.system_user_scenario
1718

18-
expect(translated).to eq(ollama_version)
19+
expect(translated).to eq(ollama_version)
20+
end
21+
end
22+
23+
context "when native tool support is disabled - XML tools" do
24+
it "includes the instructions in the system message" do
25+
allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(false)
26+
27+
DiscourseAi::Completions::Dialects::XmlTools
28+
.any_instance
29+
.stubs(:instructions)
30+
.returns("Instructions")
31+
32+
ollama_version = [
33+
{ role: "system", content: "#{context.system_insts}\n\nInstructions" },
34+
{ role: "user", content: context.simple_user_input },
35+
]
36+
37+
translated = context.system_user_scenario
38+
39+
expect(translated).to eq(ollama_version)
40+
end
1941
end
2042

2143
it "trims content if it's getting too long" do
@@ -33,4 +55,40 @@
3355
expect(context.dialect(nil).max_prompt_tokens).to eq(10_000)
3456
end
3557
end
58+
59+
describe "#tools" do
60+
context "when native tools are enabled" do
61+
it "returns the translated tools from the OllamaTools class" do
62+
tool = instance_double(DiscourseAi::Completions::Dialects::OllamaTools)
63+
64+
allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(true)
65+
allow(tool).to receive(:translated_tools)
66+
allow(DiscourseAi::Completions::Dialects::OllamaTools).to receive(:new).and_return(tool)
67+
68+
context.dialect_tools
69+
70+
expect(DiscourseAi::Completions::Dialects::OllamaTools).to have_received(:new).with(
71+
context.prompt.tools,
72+
)
73+
expect(tool).to have_received(:translated_tools)
74+
end
75+
end
76+
77+
context "when native tools are disabled" do
78+
it "returns the translated tools from the XmlTools class" do
79+
tool = instance_double(DiscourseAi::Completions::Dialects::XmlTools)
80+
81+
allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(false)
82+
allow(tool).to receive(:translated_tools)
83+
allow(DiscourseAi::Completions::Dialects::XmlTools).to receive(:new).and_return(tool)
84+
85+
context.dialect_tools
86+
87+
expect(DiscourseAi::Completions::Dialects::XmlTools).to have_received(:new).with(
88+
context.prompt.tools,
89+
)
90+
expect(tool).to have_received(:translated_tools)
91+
end
92+
end
93+
end
3694
end

0 commit comments

Comments
 (0)