Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 545500b

Browse files
authored
FEATURE: allows forced LLM tool use (#818)
* FEATURE: allows forced LLM tool use Sometimes we need to force LLMs to use tools, for example in RAG like use cases we may want to force an unconditional search. The new framework allows you backend to force tool usage. Front end commit to follow * UI for forcing tools now works, but it does not react right * fix bugs * fix tests, this is now ready for review
1 parent c294b6d commit 545500b

File tree

17 files changed

+236
-38
lines changed

17 files changed

+236
-38
lines changed

app/controllers/discourse_ai/admin/ai_personas_controller.rb

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,12 @@ def ai_persona_params
120120
def permit_tools(tools)
121121
return [] if !tools.is_a?(Array)
122122

123-
tools.filter_map do |tool, options|
123+
tools.filter_map do |tool, options, force_tool|
124124
break nil if !tool.is_a?(String)
125125
options&.permit! if options && options.is_a?(ActionController::Parameters)
126126

127-
if options
128-
[tool, options]
129-
else
130-
tool
131-
end
127+
# this is simpler from a storage perspective, 1 way to store tools
128+
[tool, options, !!force_tool]
132129
end
133130
end
134131
end

app/models/ai_persona.rb

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,23 @@ def class_instance
136136
end
137137

138138
options = {}
139+
force_tool_use = []
140+
139141
tools =
140142
self.tools.filter_map do |element|
141143
klass = nil
142144

143-
if element.is_a?(String) && element.start_with?("custom-")
144-
custom_tool_id = element.split("-", 2).last.to_i
145+
element = [element] if element.is_a?(String)
146+
147+
inner_name, current_options, should_force_tool_use =
148+
element.is_a?(Array) ? element : [element, nil]
149+
150+
if inner_name.start_with?("custom-")
151+
custom_tool_id = inner_name.split("-", 2).last.to_i
145152
if AiTool.exists?(id: custom_tool_id, enabled: true)
146153
klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id)
147154
end
148155
else
149-
inner_name, current_options = element.is_a?(Array) ? element : [element, nil]
150156
inner_name = inner_name.gsub("Tool", "")
151157
inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name)
152158

@@ -155,9 +161,10 @@ def class_instance
155161
options[klass] = current_options if current_options
156162
rescue StandardError
157163
end
158-
159-
klass
160164
end
165+
166+
force_tool_use << klass if should_force_tool_use
167+
klass
161168
end
162169

163170
ai_persona_id = self.id
@@ -177,6 +184,7 @@ def class_instance
177184
end
178185

179186
define_method(:tools) { tools }
187+
define_method(:force_tool_use) { force_tool_use }
180188
define_method(:options) { options }
181189
define_method(:temperature) { @ai_persona&.temperature }
182190
define_method(:top_p) { @ai_persona&.top_p }

assets/javascripts/discourse/admin/models/ai-persona.js

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,21 +59,25 @@ class ToolOption {
5959
export default class AiPersona extends RestModel {
6060
// this code is here to convert the wire schema to easier to work with object
6161
// on the wire we pass in/out tools as an Array.
62-
// [[ToolName, {option1: value, option2: value}], ToolName2, ToolName3]
62+
// [[ToolName, {option1: value, option2: value}, force], ToolName2, ToolName3]
6363
// So we rework this into a "tools" property and nested toolOptions
6464
init(properties) {
65+
this.forcedTools = [];
6566
if (properties.tools) {
6667
properties.tools = properties.tools.map((tool) => {
6768
if (typeof tool === "string") {
6869
return tool;
6970
} else {
70-
let [toolId, options] = tool;
71+
let [toolId, options, force] = tool;
7172
for (let optionId in options) {
7273
if (!options.hasOwnProperty(optionId)) {
7374
continue;
7475
}
7576
this.getToolOption(toolId, optionId).value = options[optionId];
7677
}
78+
if (force) {
79+
this.forcedTools.push(toolId);
80+
}
7781
return toolId;
7882
}
7983
});
@@ -109,6 +113,8 @@ export default class AiPersona extends RestModel {
109113
if (typeof toolId !== "string") {
110114
toolId = toolId[0];
111115
}
116+
117+
let force = this.forcedTools.includes(toolId);
112118
if (this.toolOptions && this.toolOptions[toolId]) {
113119
let options = this.toolOptions[toolId];
114120
let optionsWithValues = {};
@@ -119,9 +125,9 @@ export default class AiPersona extends RestModel {
119125
let option = options[optionId];
120126
optionsWithValues[optionId] = option.value;
121127
}
122-
toolsWithOptions.push([toolId, optionsWithValues]);
128+
toolsWithOptions.push([toolId, optionsWithValues, force]);
123129
} else {
124-
toolsWithOptions.push(toolId);
130+
toolsWithOptions.push([toolId, {}, force]);
125131
}
126132
});
127133
attrs.tools = toolsWithOptions;
@@ -133,7 +139,6 @@ export default class AiPersona extends RestModel {
133139
: this.getProperties(CREATE_ATTRIBUTES);
134140
attrs.id = this.id;
135141
this.populateToolOptions(attrs);
136-
137142
return attrs;
138143
}
139144

@@ -146,6 +151,9 @@ export default class AiPersona extends RestModel {
146151
workingCopy() {
147152
let attrs = this.getProperties(CREATE_ATTRIBUTES);
148153
this.populateToolOptions(attrs);
149-
return AiPersona.create(attrs);
154+
155+
const persona = AiPersona.create(attrs);
156+
persona.forcedTools = (this.forcedTools || []).slice();
157+
return persona;
150158
}
151159
}

assets/javascripts/discourse/components/ai-persona-editor.gjs

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,52 @@ export default class PersonaEditor extends Component {
4040
@tracked maxPixelsValue = null;
4141
@tracked ragIndexingStatuses = null;
4242

43+
@tracked selectedTools = [];
44+
@tracked selectedToolNames = [];
45+
@tracked forcedToolNames = [];
46+
4347
get chatPluginEnabled() {
4448
return this.siteSettings.chat_enabled;
4549
}
4650

51+
get allowForceTools() {
52+
return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
53+
}
54+
55+
@action
56+
forcedToolsChanged(tools) {
57+
this.forcedToolNames = tools;
58+
this.editingModel.forcedTools = this.forcedToolNames;
59+
}
60+
61+
@action
62+
toolsChanged(tools) {
63+
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
64+
tools.includes(tool.id)
65+
);
66+
this.selectedToolNames = tools.slice();
67+
68+
this.forcedToolNames = this.forcedToolNames.filter(
69+
(tool) => this.editingModel.tools.indexOf(tool) !== -1
70+
);
71+
72+
this.editingModel.tools = this.selectedToolNames;
73+
this.editingModel.forcedTools = this.forcedToolNames;
74+
}
75+
4776
@action
4877
updateModel() {
4978
this.editingModel = this.args.model.workingCopy();
5079
this.showDelete = !this.args.model.isNew && !this.args.model.system;
5180
this.maxPixelsValue = this.findClosestPixelValue(
5281
this.editingModel.vision_max_pixels
5382
);
83+
84+
this.selectedToolNames = this.editingModel.tools || [];
85+
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
86+
this.selectedToolNames.includes(tool.id)
87+
);
88+
this.forcedToolNames = this.editingModel.forcedTools || [];
5489
}
5590

5691
findClosestPixelValue(pixels) {
@@ -336,15 +371,27 @@ export default class PersonaEditor extends Component {
336371
<label>{{I18n.t "discourse_ai.ai_persona.tools"}}</label>
337372
<AiToolSelector
338373
class="ai-persona-editor__tools"
339-
@value={{this.editingModel.tools}}
374+
@value={{this.selectedToolNames}}
340375
@disabled={{this.editingModel.system}}
341376
@tools={{@personas.resultSetMeta.tools}}
377+
@onChange={{this.toolsChanged}}
342378
/>
343379
</div>
380+
{{#if this.allowForceTools}}
381+
<div class="control-group">
382+
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
383+
<AiToolSelector
384+
class="ai-persona-editor__tools"
385+
@value={{this.forcedToolNames}}
386+
@tools={{this.selectedTools}}
387+
@onChange={{this.forcedToolsChanged}}
388+
/>
389+
</div>
390+
{{/if}}
344391
{{#unless this.editingModel.system}}
345392
<AiPersonaToolOptions
346393
@persona={{this.editingModel}}
347-
@tools={{this.editingModel.tools}}
394+
@tools={{this.selectedToolNames}}
348395
@allTools={{@personas.resultSetMeta.tools}}
349396
/>
350397
{{/unless}}

assets/javascripts/discourse/components/ai-tool-selector.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ export default MultiSelectComponent.extend({
66
this.selectKit.options.set("disabled", this.get("attrs.disabled.value"));
77
}),
88

9-
content: computed(function () {
9+
content: computed("tools", function () {
1010
return this.tools;
1111
}),
1212

config/locales/client.en.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ en:
148148
saved: AI Persona Saved
149149
enabled: "Enabled?"
150150
tools: Enabled Tools
151+
forced_tools: Forced Tools
151152
allowed_groups: Allowed Groups
152153
confirm_delete: Are you sure you want to delete this persona?
153154
new: "New Persona"

lib/ai_bot/bot.rb

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,19 @@ def get_updated_title(conversation_context, post)
6767
.last
6868
end
6969

70+
def force_tool_if_needed(prompt, context)
71+
context[:chosen_tools] ||= []
72+
forced_tools = persona.force_tool_use.map { |tool| tool.name }
73+
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
74+
75+
if force_tool
76+
context[:chosen_tools] << force_tool
77+
prompt.tool_choice = force_tool
78+
else
79+
prompt.tool_choice = nil
80+
end
81+
end
82+
7083
def reply(context, &update_blk)
7184
llm = DiscourseAi::Completions::Llm.proxy(model)
7285
prompt = persona.craft_prompt(context, llm: llm)
@@ -85,6 +98,7 @@ def reply(context, &update_blk)
8598

8699
while total_completions <= MAX_COMPLETIONS && ongoing_chain
87100
tool_found = false
101+
force_tool_if_needed(prompt, context)
88102

89103
result =
90104
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|

lib/ai_bot/personas/persona.rb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ def tools
113113
[]
114114
end
115115

116+
def force_tool_use
117+
[]
118+
end
119+
116120
def required_tools
117121
[]
118122
end

lib/completions/dialects/dialect.rb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def tools
6060
@tools ||= tools_dialect.translated_tools
6161
end
6262

63+
def tool_choice
64+
prompt.tool_choice
65+
end
66+
6367
def translate
6468
messages = prompt.messages
6569

lib/completions/endpoints/open_ai.rb

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,12 @@ def prepare_payload(prompt, model_params, dialect)
5454
# We'll fallback to guess this using the tokenizer.
5555
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
5656
end
57-
58-
payload[:tools] = dialect.tools if dialect.tools.present?
57+
if dialect.tools.present?
58+
payload[:tools] = dialect.tools
59+
if dialect.tool_choice.present?
60+
payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } }
61+
end
62+
end
5963
payload
6064
end
6165

0 commit comments

Comments
 (0)