Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit e817b7d

Browse files
authored
FEATURE: improve tool support (#904)
This re-implements tool support in DiscourseAi::Completions::Llm#generate. Previously, tool support was always returned via XML, and it was the responsibility of the caller to parse that XML. The new implementation has the endpoints return ToolCall objects. Additionally, this simplifies the Llm endpoint interface and gives it more clarity. Llms must implement decode and decode_chunk (for streaming). It is the implementer's responsibility to figure out how to decode chunks; the base class no longer implements this. To make this easy, we ship a flexible JSON decoder which is easy to wire up. Also (new): better debugging for PMs — we now have next/previous buttons to see all the Llm messages associated with a PM. Token accounting is fixed for vLLM (we were not correctly counting tokens).
1 parent 644141f commit e817b7d

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

43 files changed

+1685
-1293
lines changed

app/controllers/discourse_ai/ai_bot/bot_controller.rb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@ class BotController < ::ApplicationController
66
requires_plugin ::DiscourseAi::PLUGIN_NAME
77
requires_login
88

9+
def show_debug_info_by_id
10+
log = AiApiAuditLog.find(params[:id])
11+
raise Discourse::NotFound if !log.topic
12+
13+
guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
14+
render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
15+
end
16+
917
def show_debug_info
1018
post = Post.find(params[:post_id])
1119
guardian.ensure_can_debug_ai_bot_conversation!(post)

app/models/ai_api_audit_log.rb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ module Provider
1414
Ollama = 7
1515
SambaNova = 8
1616
end
17+
18+
def next_log_id
19+
self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
20+
end
21+
22+
def prev_log_id
23+
self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
24+
end
1725
end
1826

1927
# == Schema Information

app/serializers/ai_api_audit_log_serializer.rb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
1212
:post_id,
1313
:feature_name,
1414
:language_model,
15-
:created_at
15+
:created_at,
16+
:prev_log_id,
17+
:next_log_id
1618
end

assets/javascripts/discourse/components/modal/debug-ai-modal.gjs

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
77
import DButton from "discourse/components/d-button";
88
import DModal from "discourse/components/d-modal";
99
import { ajax } from "discourse/lib/ajax";
10+
import { popupAjaxError } from "discourse/lib/ajax-error";
1011
import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
1112
import i18n from "discourse-common/helpers/i18n";
1213
import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
6364
this.copy(this.info.raw_response_payload);
6465
}
6566

67+
async loadLog(logId) {
68+
try {
69+
await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
70+
(result) => {
71+
this.info = result;
72+
}
73+
);
74+
} catch (e) {
75+
popupAjaxError(e);
76+
}
77+
}
78+
79+
@action
80+
prevLog() {
81+
this.loadLog(this.info.prev_log_id);
82+
}
83+
84+
@action
85+
nextLog() {
86+
this.loadLog(this.info.next_log_id);
87+
}
88+
6689
copy(text) {
6790
clipboardCopy(text);
6891
this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -73,11 +96,13 @@ export default class DebugAiModal extends Component {
7396
}
7497

7598
loadApiRequestInfo() {
76-
ajax(
77-
`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
78-
).then((result) => {
79-
this.info = result;
80-
});
99+
ajax(`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`)
100+
.then((result) => {
101+
this.info = result;
102+
})
103+
.catch((e) => {
104+
popupAjaxError(e);
105+
});
81106
}
82107

83108
get requestActive() {
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
147172
@action={{this.copyResponse}}
148173
@label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
149174
/>
175+
{{#if this.info.prev_log_id}}
176+
<DButton
177+
class="btn"
178+
@icon="angles-left"
179+
@action={{this.prevLog}}
180+
@label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
181+
/>
182+
{{/if}}
183+
{{#if this.info.next_log_id}}
184+
<DButton
185+
class="btn"
186+
@icon="angles-right"
187+
@action={{this.nextLog}}
188+
@label="discourse_ai.ai_bot.debug_ai_modal.next_log"
189+
/>
190+
{{/if}}
150191
<span class="ai-debut-modal__just-copied">{{this.justCopiedText}}</span>
151192
</:footer>
152193
</DModal>

config/locales/client.en.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,8 @@ en:
415415
response_tokens: "Response tokens:"
416416
request: "Request"
417417
response: "Response"
418+
next_log: "Next"
419+
previous_log: "Previous"
418420

419421
share_full_topic_modal:
420422
title: "Share Conversation Publicly"

config/routes.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
2323
get "bot-username" => "bot#show_bot_username"
2424
get "post/:post_id/show-debug-info" => "bot#show_debug_info"
25+
get "show-debug-info/:id" => "bot#show_debug_info_by_id"
2526
post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
2627
end
2728

lib/ai_bot/bot.rb

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,30 +100,35 @@ def reply(context, &update_blk)
100100
llm_kwargs[:top_p] = persona.top_p if persona.top_p
101101

102102
needs_newlines = false
103+
tools_ran = 0
103104

104105
while total_completions <= MAX_COMPLETIONS && ongoing_chain
105106
tool_found = false
106107
force_tool_if_needed(prompt, context)
107108

108109
result =
109110
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
110-
tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
111+
tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
112+
tool = nil if tools_ran >= MAX_TOOLS
111113

112-
if (tools.present?)
114+
if tool.present?
113115
tool_found = true
114116
# a bit hacky, but extra newlines do no harm
115117
if needs_newlines
116118
update_blk.call("\n\n", cancel)
117119
needs_newlines = false
118120
end
119121

120-
tools[0..MAX_TOOLS].each do |tool|
121-
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
122-
ongoing_chain &&= tool.chain_next_response?
123-
end
122+
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
123+
tools_ran += 1
124+
ongoing_chain &&= tool.chain_next_response?
124125
else
125126
needs_newlines = true
126-
update_blk.call(partial, cancel)
127+
if partial.is_a?(DiscourseAi::Completions::ToolCall)
128+
Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}")
129+
else
130+
update_blk.call(partial, cancel)
131+
end
127132
end
128133
end
129134

lib/ai_bot/personas/persona.rb

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -199,23 +199,16 @@ def craft_prompt(context, llm: nil)
199199
prompt
200200
end
201201

202-
def find_tools(partial, bot_user:, llm:, context:)
203-
return [] if !partial.include?("</invoke>")
204-
205-
parsed_function = Nokogiri::HTML5.fragment(partial)
206-
parsed_function
207-
.css("invoke")
208-
.map do |fragment|
209-
tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
210-
end
211-
.compact
202+
def find_tool(partial, bot_user:, llm:, context:)
203+
return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
204+
tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
212205
end
213206

214207
protected
215208

216-
def tool_instance(parsed_function, bot_user:, llm:, context:)
217-
function_id = parsed_function.at("tool_id")&.text
218-
function_name = parsed_function.at("tool_name")&.text
209+
def tool_instance(tool_call, bot_user:, llm:, context:)
210+
function_id = tool_call.id
211+
function_name = tool_call.name
219212
return nil if function_name.nil?
220213

221214
tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +217,7 @@ def tool_instance(parsed_function, bot_user:, llm:, context:)
224217
arguments = {}
225218
tool_klass.signature[:parameters].to_a.each do |param|
226219
name = param[:name]
227-
value = parsed_function.at(name)&.text
220+
value = tool_call.parameters[name.to_sym]
228221

229222
if param[:type] == "array" && value
230223
value =

lib/completions/anthropic_message_processor.rb

Lines changed: 56 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -13,87 +13,81 @@ def initialize(name, id)
1313
def append(json)
1414
@raw_json << json
1515
end
16+
17+
def to_tool_call
18+
parameters = JSON.parse(raw_json, symbolize_names: true)
19+
DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
20+
end
1621
end
1722

1823
attr_reader :tool_calls, :input_tokens, :output_tokens
1924

2025
def initialize(streaming_mode:)
2126
@streaming_mode = streaming_mode
2227
@tool_calls = []
28+
@current_tool_call = nil
2329
end
2430

25-
def to_xml_tool_calls(function_buffer)
26-
return function_buffer if @tool_calls.blank?
27-
28-
function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
29-
<function_calls>
30-
</function_calls>
31-
TEXT
32-
33-
@tool_calls.each do |tool_call|
34-
node =
35-
function_buffer.at("function_calls").add_child(
36-
Nokogiri::HTML5::DocumentFragment.parse(
37-
DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
38-
),
39-
)
40-
41-
params = JSON.parse(tool_call.raw_json, symbolize_names: true)
42-
xml =
43-
params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
31+
def to_tool_calls
32+
@tool_calls.map { |tool_call| tool_call.to_tool_call }
33+
end
4434

45-
node.at("tool_name").content = tool_call.name
46-
node.at("tool_id").content = tool_call.id
47-
node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
35+
def process_streamed_message(parsed)
36+
result = nil
37+
if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
38+
tool_name = parsed.dig(:content_block, :name)
39+
tool_id = parsed.dig(:content_block, :id)
40+
result = @current_tool_call.to_tool_call if @current_tool_call
41+
@current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
42+
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
43+
if @current_tool_call
44+
tool_delta = parsed.dig(:delta, :partial_json).to_s
45+
@current_tool_call.append(tool_delta)
46+
else
47+
result = parsed.dig(:delta, :text).to_s
48+
end
49+
elsif parsed[:type] == "content_block_stop"
50+
if @current_tool_call
51+
result = @current_tool_call.to_tool_call
52+
@current_tool_call = nil
53+
end
54+
elsif parsed[:type] == "message_start"
55+
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
56+
elsif parsed[:type] == "message_delta"
57+
@output_tokens =
58+
parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
59+
elsif parsed[:type] == "message_stop"
60+
# bedrock has this ...
61+
if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
62+
@input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
63+
@output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
64+
end
4865
end
49-
50-
function_buffer
66+
result
5167
end
5268

5369
def process_message(payload)
5470
result = ""
55-
parsed = JSON.parse(payload, symbolize_names: true)
71+
parsed = payload
72+
parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)
5673

57-
if @streaming_mode
58-
if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
59-
tool_name = parsed.dig(:content_block, :name)
60-
tool_id = parsed.dig(:content_block, :id)
61-
@tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
62-
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
63-
if @tool_calls.present?
64-
result = parsed.dig(:delta, :partial_json).to_s
65-
@tool_calls.last.append(result)
66-
else
67-
result = parsed.dig(:delta, :text).to_s
68-
end
69-
elsif parsed[:type] == "message_start"
70-
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
71-
elsif parsed[:type] == "message_delta"
72-
@output_tokens =
73-
parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
74-
elsif parsed[:type] == "message_stop"
75-
# bedrock has this ...
76-
if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
77-
@input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
78-
@output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
79-
end
80-
end
81-
else
82-
content = parsed.dig(:content)
83-
if content.is_a?(Array)
84-
tool_call = content.find { |c| c[:type] == "tool_use" }
85-
if tool_call
86-
@tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
87-
@tool_calls.last.append(tool_call[:input].to_json)
88-
else
89-
result = parsed.dig(:content, 0, :text).to_s
74+
content = parsed.dig(:content)
75+
if content.is_a?(Array)
76+
result =
77+
content.map do |data|
78+
if data[:type] == "tool_use"
79+
call = AnthropicToolCall.new(data[:name], data[:id])
80+
call.append(data[:input].to_json)
81+
call.to_tool_call
82+
else
83+
data[:text]
84+
end
9085
end
91-
end
92-
93-
@input_tokens = parsed.dig(:usage, :input_tokens)
94-
@output_tokens = parsed.dig(:usage, :output_tokens)
9586
end
9687

88+
@input_tokens = parsed.dig(:usage, :input_tokens)
89+
@output_tokens = parsed.dig(:usage, :output_tokens)
90+
9791
result
9892
end
9993
end

lib/completions/dialects/ollama.rb

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,23 @@ def enable_native_tool?
6363
def user_msg(msg)
6464
user_message = { role: "user", content: msg[:content] }
6565

66-
# TODO: Add support for user messages with empbeded user ids
67-
# TODO: Add support for user messages with attachments
66+
encoded_uploads = prompt.encoded_uploads(msg)
67+
if encoded_uploads.present?
68+
images =
69+
encoded_uploads
70+
.map do |upload|
71+
if upload[:mime_type].start_with?("image/")
72+
upload[:base64]
73+
else
74+
nil
75+
end
76+
end
77+
.compact
78+
79+
user_message[:images] = images if images.present?
80+
end
81+
82+
# TODO: Add support for user messages with embedded user ids
6883

6984
user_message
7085
end

0 commit comments

Comments
 (0)