
Commit bb6df42

FEATURE: improve tool support
This work-in-progress PR amends LLM completion so that it returns objects for tool calls instead of XML fragments. This will empower future features such as parameter streaming. The XML approach was error prone; the object implementation is more robust. Still very much in progress, and a lot of code needs to change. Partially implemented on Anthropic at the moment.
1 parent 1ad5321 commit bb6df42
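For context, the diff below is built around a `DiscourseAi::Completions::ToolCall` value object. That class itself is not part of this commit, so the following is only a minimal sketch of the shape the rest of the diff assumes: an id, a name, and a symbol-keyed parameters hash, matching how the code uses `tool_call.id`, `tool_call.name`, and `tool_call.parameters[name.to_sym]`.

# Minimal sketch only; the real ToolCall class is not included in this commit.
module DiscourseAi
  module Completions
    class ToolCall
      attr_reader :id, :name, :parameters

      def initialize(id:, name:, parameters: {})
        @id = id
        @name = name
        @parameters = parameters # symbol-keyed hash parsed from the model's JSON
      end
    end
  end
end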

18 files changed: +485 −140 lines

app/controllers/discourse_ai/ai_bot/bot_controller.rb

Lines changed: 10 additions & 0 deletions
@@ -6,6 +6,16 @@ class BotController < ::ApplicationController
       requires_plugin ::DiscourseAi::PLUGIN_NAME
       requires_login
 
+      def show_debug_info_by_id
+        log = AiApiAuditLog.find(params[:id])
+        if !log.topic
+          raise Discourse::NotFound
+        end
+
+        guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
+        render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
+      end
+
       def show_debug_info
         post = Post.find(params[:post_id])
         guardian.ensure_can_debug_ai_bot_conversation!(post)
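A hedged usage sketch of the new action, assuming Discourse's request-spec helpers and the `show-debug-info/:id` route added later in this commit; the log id is hypothetical.

# Illustrative only: exercising the new endpoint in a request spec.
sign_in(admin)

log = AiApiAuditLog.find(42) # hypothetical id
get "/discourse-ai/ai-bot/show-debug-info/#{log.id}.json"
# => 200 with the serialized log when the user may debug the conversation,
#    404 when the log is not attached to a topic.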

app/models/ai_api_audit_log.rb

Lines changed: 8 additions & 0 deletions
@@ -14,6 +14,14 @@ module Provider
     Ollama = 7
     SambaNova = 8
   end
+
+  def next_log_id
+    self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
+  end
+
+  def prev_log_id
+    self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
+  end
 end
 
 # == Schema Information
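These two helpers let callers walk a topic's audit logs in id order. A small usage sketch; the record id is hypothetical.

# Walk neighbouring logs within the same topic.
log = AiApiAuditLog.find(123) # hypothetical id

newer = AiApiAuditLog.find(log.next_log_id) if log.next_log_id
older = AiApiAuditLog.find(log.prev_log_id) if log.prev_log_id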

app/serializers/ai_api_audit_log_serializer.rb

Lines changed: 3 additions & 1 deletion
@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
              :post_id,
              :feature_name,
              :language_model,
-             :created_at
+             :created_at,
+             :prev_log_id,
+             :next_log_id
 end

assets/javascripts/discourse/components/modal/debug-ai-modal.gjs

Lines changed: 41 additions & 0 deletions
@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
 import DButton from "discourse/components/d-button";
 import DModal from "discourse/components/d-modal";
 import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
 import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
 import i18n from "discourse-common/helpers/i18n";
 import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
     this.copy(this.info.raw_response_payload);
   }
 
+  async loadLog(logId) {
+    try {
+      await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
+        (result) => {
+          this.info = result;
+        }
+      );
+    } catch (e) {
+      popupAjaxError(e);
+    }
+  }
+
+  @action
+  prevLog() {
+    this.loadLog(this.info.prev_log_id);
+  }
+
+  @action
+  nextLog() {
+    this.loadLog(this.info.next_log_id);
+  }
+
   copy(text) {
     clipboardCopy(text);
     this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -77,6 +100,8 @@ export default class DebugAiModal extends Component {
       `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
     ).then((result) => {
       this.info = result;
+    }).catch((e) => {
+      popupAjaxError(e);
     });
   }
 
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
           @action={{this.copyResponse}}
           @label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
         />
+        {{#if this.info.prev_log_id}}
+          <DButton
+            class="btn"
+            @icon="angles-left"
+            @action={{this.prevLog}}
+            @label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
+          />
+        {{/if}}
+        {{#if this.info.next_log_id}}
+          <DButton
+            class="btn"
+            @icon="angles-right"
+            @action={{this.nextLog}}
+            @label="discourse_ai.ai_bot.debug_ai_modal.next_log"
+          />
+        {{/if}}
         <span class="ai-debut-modal__just-copied">{{this.justCopiedText}}</span>
       </:footer>
     </DModal>

config/locales/client.en.yml

Lines changed: 2 additions & 0 deletions
@@ -415,6 +415,8 @@ en:
           response_tokens: "Response tokens:"
           request: "Request"
           response: "Response"
+          next_log: "Next"
+          previous_log: "Previous"
 
         share_full_topic_modal:
           title: "Share Conversation Publicly"

config/routes.rb

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
   scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
     get "bot-username" => "bot#show_bot_username"
     get "post/:post_id/show-debug-info" => "bot#show_debug_info"
+    get "show-debug-info/:id" => "bot#show_debug_info_by_id"
     post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
   end

lib/ai_bot/bot.rb

Lines changed: 7 additions & 6 deletions
@@ -100,27 +100,28 @@ def reply(context, &update_blk)
         llm_kwargs[:top_p] = persona.top_p if persona.top_p
 
         needs_newlines = false
+        tools_ran = 0
 
         while total_completions <= MAX_COMPLETIONS && ongoing_chain
           tool_found = false
           force_tool_if_needed(prompt, context)
 
           result =
             llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
-              tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
+              tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
+              tool = nil if tools_ran >= MAX_TOOLS
 
-              if (tools.present?)
+              if (tool.present?)
                 tool_found = true
                 # a bit hacky, but extra newlines do no harm
                 if needs_newlines
                   update_blk.call("\n\n", cancel)
                   needs_newlines = false
                 end
 
-                tools[0..MAX_TOOLS].each do |tool|
-                  process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
-                  ongoing_chain &&= tool.chain_next_response?
-                end
+                process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
+                tools_ran += 1
+                ongoing_chain &&= tool.chain_next_response?
               else
                 needs_newlines = true
                 update_blk.call(partial, cancel)
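The loop above now relies on `llm.generate` yielding either a plain text chunk or a `DiscourseAi::Completions::ToolCall`. A simplified sketch of that contract, not the full reply implementation:

# Simplified sketch of the new yield contract.
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
  if partial.is_a?(DiscourseAi::Completions::ToolCall)
    # at most MAX_TOOLS tool invocations are processed per reply
    process_tool(partial, raw_context, llm, cancel, update_blk, prompt, context)
  else
    update_blk.call(partial, cancel) # streamed text fragment
  end
end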

lib/ai_bot/personas/persona.rb

Lines changed: 8 additions & 14 deletions
@@ -199,23 +199,17 @@ def craft_prompt(context, llm: nil)
       prompt
     end
 
-    def find_tools(partial, bot_user:, llm:, context:)
-      return [] if !partial.include?("</invoke>")
-
-      parsed_function = Nokogiri::HTML5.fragment(partial)
-      parsed_function
-        .css("invoke")
-        .map do |fragment|
-          tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
-        end
-        .compact
+    def find_tool(partial, bot_user:, llm:, context:)
+      return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
+      tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
     end
 
    protected
 
-    def tool_instance(parsed_function, bot_user:, llm:, context:)
-      function_id = parsed_function.at("tool_id")&.text
-      function_name = parsed_function.at("tool_name")&.text
+    def tool_instance(tool_call, bot_user:, llm:, context:)
+
+      function_id = tool_call.id
+      function_name = tool_call.name
       return nil if function_name.nil?
 
       tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +218,7 @@ def tool_instance(parsed_function, bot_user:, llm:, context:)
       arguments = {}
       tool_klass.signature[:parameters].to_a.each do |param|
         name = param[:name]
-        value = parsed_function.at(name)&.text
+        value = tool_call.parameters[name.to_sym]
 
         if param[:type] == "array" && value
           value =
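Illustrative only: how a `ToolCall` flows through the new `find_tool`; the id, tool name, and parameters below are made up for the example.

# Hypothetical values; shows the object-based path replacing the old XML parsing.
tool_call =
  DiscourseAi::Completions::ToolCall.new(
    id: "toolu_123",
    name: "search",
    parameters: { query: "discourse ai" },
  )

tool = persona.find_tool(tool_call, bot_user: user, llm: llm, context: context)
# => an instance of the matching tool class, or nil when the partial is not a ToolCall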

lib/completions/anthropic_message_processor.rb

Lines changed: 56 additions & 61 deletions
@@ -13,87 +13,82 @@ def initialize(name, id)
     def append(json)
       @raw_json << json
     end
+
+    def to_tool_call
+      parameters = JSON.parse(raw_json, symbolize_names: true)
+      DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
+    end
   end
 
   attr_reader :tool_calls, :input_tokens, :output_tokens
 
   def initialize(streaming_mode:)
     @streaming_mode = streaming_mode
     @tool_calls = []
+    @current_tool_call = nil
   end
 
-  def to_xml_tool_calls(function_buffer)
-    return function_buffer if @tool_calls.blank?
-
-    function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
-      <function_calls>
-      </function_calls>
-    TEXT
-
-    @tool_calls.each do |tool_call|
-      node =
-        function_buffer.at("function_calls").add_child(
-          Nokogiri::HTML5::DocumentFragment.parse(
-            DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
-          ),
-        )
-
-      params = JSON.parse(tool_call.raw_json, symbolize_names: true)
-      xml =
-        params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
+  def to_tool_calls
+    @tool_calls.map { |tool_call| tool_call.to_tool_call }
+  end
 
-      node.at("tool_name").content = tool_call.name
-      node.at("tool_id").content = tool_call.id
-      node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
+  def process_streamed_message(parsed)
+    result = nil
+    if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
+      tool_name = parsed.dig(:content_block, :name)
+      tool_id = parsed.dig(:content_block, :id)
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+      end
+      @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
+    elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+      if @current_tool_call
+        tool_delta = parsed.dig(:delta, :partial_json).to_s
+        @current_tool_call.append(tool_delta)
+      else
+        result = parsed.dig(:delta, :text).to_s
+      end
+    elsif parsed[:type] == "content_block_stop"
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+        @current_tool_call = nil
+      end
+    elsif parsed[:type] == "message_start"
+      @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+    elsif parsed[:type] == "message_delta"
+      @output_tokens =
+        parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
+    elsif parsed[:type] == "message_stop"
+      # bedrock has this ...
+      if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
+        @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
+        @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
+      end
     end
-
-    function_buffer
+    result
   end
 
   def process_message(payload)
     result = ""
     parsed = JSON.parse(payload, symbolize_names: true)
 
-    if @streaming_mode
-      if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
-        tool_name = parsed.dig(:content_block, :name)
-        tool_id = parsed.dig(:content_block, :id)
-        @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
-      elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
-        if @tool_calls.present?
-          result = parsed.dig(:delta, :partial_json).to_s
-          @tool_calls.last.append(result)
-        else
-          result = parsed.dig(:delta, :text).to_s
+    content = parsed.dig(:content)
+    if content.is_a?(Array)
+      result =
+        content.map do |data|
+          if data[:type] == "tool_use"
+            call = AnthropicToolCall.new(data[:name], data[:id])
+            call.append(data[:input].to_json)
+            call.to_tool_call
+          else
+            data[:text]
+          end
         end
-      elsif parsed[:type] == "message_start"
-        @input_tokens = parsed.dig(:message, :usage, :input_tokens)
-      elsif parsed[:type] == "message_delta"
-        @output_tokens =
-          parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
-      elsif parsed[:type] == "message_stop"
-        # bedrock has this ...
-        if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
-          @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
-          @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
-        end
-      end
-    else
-      content = parsed.dig(:content)
-      if content.is_a?(Array)
-        tool_call = content.find { |c| c[:type] == "tool_use" }
-        if tool_call
-          @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
-          @tool_calls.last.append(tool_call[:input].to_json)
-        else
-          result = parsed.dig(:content, 0, :text).to_s
-        end
-      end
-
-      @input_tokens = parsed.dig(:usage, :input_tokens)
-      @output_tokens = parsed.dig(:usage, :output_tokens)
     end
 
+    @input_tokens = parsed.dig(:usage, :input_tokens)
+    @output_tokens = parsed.dig(:usage, :output_tokens)
+
     result
   end
 end
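A worked example of how `process_streamed_message` maps Anthropic stream events to results, based on the branches shown above; the event hashes are heavily simplified and the tool name and id are made up.

# Simplified event hashes; real Anthropic payloads carry more fields.
processor =
  DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: true)

processor.process_streamed_message({ type: "content_block_delta", delta: { text: "Hello" } })
# => "Hello" (plain text chunk)

processor.process_streamed_message(
  { type: "content_block_start", content_block: { type: "tool_use", name: "search", id: "toolu_1" } },
)
# => nil (starts buffering a tool call)

processor.process_streamed_message(
  { type: "content_block_delta", delta: { partial_json: "{\"query\":\"foo\"}" } },
)
# => nil (appends to the buffered JSON arguments)

processor.process_streamed_message({ type: "content_block_stop" })
# => DiscourseAi::Completions::ToolCall with name "search" and parameters { query: "foo" }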

lib/completions/endpoints/anthropic.rb

Lines changed: 11 additions & 4 deletions
@@ -90,15 +90,18 @@ def prepare_request(payload)
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end
 
+        def decode_chunk(partial_data)
+          @decoder ||= JsonStreamDecoder.new
+          (@decoder << partial_data).map do |parsed_json|
+            processor.process_streamed_message(parsed_json)
+          end.compact
+        end
+
         def processor
           @processor ||=
             DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
         end
 
-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          processor.to_xml_tool_calls(function_buffer) if !partial
-        end
-
         def extract_completion_from(response_raw)
           processor.process_message(response_raw)
         end
@@ -107,6 +110,10 @@ def has_tool?(_response_data)
           processor.tool_calls.present?
         end
 
+        def tool_calls
+          processor.to_tool_calls
+        end
+
         def final_log_update(log)
           log.request_tokens = processor.input_tokens if processor.input_tokens
           log.response_tokens = processor.output_tokens if processor.output_tokens
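`decode_chunk` leans on `JsonStreamDecoder`, which is not part of this diff. A minimal sketch, assuming it buffers raw SSE data and returns one parsed, symbol-keyed event hash per complete `data:` line; the real class may well differ.

require "json"

# Assumed behaviour only; the real JsonStreamDecoder is not shown in this commit.
class JsonStreamDecoder
  def initialize
    @buffer = +""
  end

  # Appends a raw network chunk and returns the events completed so far.
  def <<(raw)
    @buffer << raw
    events = []
    while (line = @buffer.slice!(/^data: .*\n/))
      events << JSON.parse(line.delete_prefix("data: "), symbolize_names: true)
    end
    events
  end
end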
