Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 7300737

Browse files
committed
- halt after tools
- post streamer ensures we don't have half completed stuff on screen when a tool is slow - reimplemnt xml tools to have a more relaxed parse
1 parent 1eb1993 commit 7300737

File tree

6 files changed

+140
-37
lines changed

6 files changed

+140
-37
lines changed

lib/ai_bot/bot.rb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def reply(context, &update_blk)
106106
tool_found = false
107107
force_tool_if_needed(prompt, context)
108108

109+
tool_halted = false
110+
109111
result =
110112
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
111113
tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
@@ -122,7 +124,12 @@ def reply(context, &update_blk)
122124
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
123125
tools_ran += 1
124126
ongoing_chain &&= tool.chain_next_response?
127+
128+
if !tool.chain_next_response?
129+
tool_halted = true
130+
end
125131
else
132+
next if tool_halted
126133
needs_newlines = true
127134
if partial.is_a?(DiscourseAi::Completions::ToolCall)
128135
Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}")

lib/ai_bot/playground.rb

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def reply_to(post, custom_instructions: nil, &blk)
399399
PostCustomPrompt.none
400400

401401
reply = +""
402-
start = Time.now
402+
post_streamer = nil
403403

404404
post_type =
405405
post.post_type == Post.types[:whisper] ? Post.types[:whisper] : Post.types[:regular]
@@ -448,6 +448,8 @@ def reply_to(post, custom_instructions: nil, &blk)
448448

449449
context[:skip_tool_details] ||= !bot.persona.class.tool_details
450450

451+
post_streamer = PostStreamer.new(delay: Rails.env.test? ? 0 : 0.5) if stream_reply
452+
451453
new_custom_prompts =
452454
bot.reply(context) do |partial, cancel, placeholder, type|
453455
reply << partial
@@ -461,22 +463,20 @@ def reply_to(post, custom_instructions: nil, &blk)
461463
reply_post.update!(raw: reply, cooked: PrettyText.cook(reply))
462464
end
463465

464-
if stream_reply
465-
# Minor hack to skip the delay during tests.
466-
if placeholder.blank?
467-
next if (Time.now - start < 0.5) && !Rails.env.test?
468-
start = Time.now
469-
end
470-
471-
Discourse.redis.expire(redis_stream_key, 60)
472-
473-
publish_update(reply_post, { raw: raw })
466+
if post_streamer
467+
post_streamer.run_later {
468+
Discourse.redis.expire(redis_stream_key, 60)
469+
publish_update(reply_post, { raw: raw })
470+
}
474471
end
475472
end
476473

477474
return if reply.blank?
478475

479476
if stream_reply
477+
post_streamer.finish
478+
post_streamer = nil
479+
480480
# land the final message prior to saving so we don't clash
481481
reply_post.cooked = PrettyText.cook(reply)
482482
publish_final_update(reply_post)
@@ -514,6 +514,7 @@ def reply_to(post, custom_instructions: nil, &blk)
514514

515515
reply_post
516516
ensure
517+
post_streamer&.finish(skip_callback: true)
517518
publish_final_update(reply_post) if stream_reply
518519
if reply_post && post.post_number == 1 && post.topic.private_message?
519520
title_playground(reply_post)

lib/ai_bot/post_streamer.rb

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# frozen_string_literal: true
2+
3+
module DiscourseAi
4+
module AiBot
5+
class PostStreamer
6+
def initialize(delay: 0.5)
7+
@mutex = Mutex.new
8+
@callback = nil
9+
@delay = delay
10+
@done = false
11+
end
12+
13+
def run_later(&callback)
14+
@mutex.synchronize { @callback = callback }
15+
ensure_worker!
16+
end
17+
18+
def finish(skip_callback: false)
19+
@mutex.synchronize do
20+
@callback&.call if skip_callback
21+
@callback = nil
22+
@done = true
23+
end
24+
25+
begin
26+
@worker_thread&.wakeup
27+
rescue StandardError
28+
ThreadError
29+
end
30+
@worker_thread&.join
31+
@worker_thread = nil
32+
end
33+
34+
private
35+
36+
def run
37+
while !@done
38+
@mutex.synchronize do
39+
callback = @callback
40+
@callback = nil
41+
callback&.call
42+
end
43+
sleep @delay
44+
end
45+
end
46+
47+
def ensure_worker!
48+
return if @worker_thread
49+
@mutex.synchronize do
50+
return if @worker_thread
51+
db = RailsMultisite::ConnectionManagement.current_db
52+
@worker_thread =
53+
Thread.new { RailsMultisite::ConnectionManagement.with_connection(db) { run } }
54+
end
55+
end
56+
end
57+
end
58+
end

lib/ai_bot/tools/create_artifact.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def self.signature
3939
end
4040

4141
def invoke
42+
yield parameters[:name] || "Web Artifact"
4243
# Get the current post from context
4344
post = Post.find_by(id: context[:post_id])
4445
return error_response("No post context found") unless post

lib/completions/xml_tool_processor.rb

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -62,31 +62,14 @@ def <<(text)
6262
def finish
6363
return [] if @function_buffer.blank?
6464

65-
xml = Nokogiri::HTML5.fragment(@function_buffer)
66-
normalize_function_ids!(xml)
67-
last_invoke = xml.at("invoke:last")
68-
if last_invoke
69-
last_invoke.next_sibling.remove while last_invoke.next_sibling
70-
xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
65+
idx = -1
66+
parse_malformed_xml(@function_buffer).map do |tool|
67+
ToolCall.new(
68+
id: "tool_#{idx += 1}",
69+
name: tool[:tool_name],
70+
parameters: tool[:parameters]
71+
)
7172
end
72-
73-
xml
74-
.css("invoke")
75-
.map do |invoke|
76-
tool_name = invoke.at("tool_name").content.force_encoding("UTF-8")
77-
tool_id = invoke.at("tool_id").content.force_encoding("UTF-8")
78-
parameters = {}
79-
invoke
80-
.at("parameters")
81-
&.children
82-
&.each do |node|
83-
next if node.text?
84-
name = node.name
85-
value = node.content.to_s
86-
parameters[name.to_sym] = value.to_s.force_encoding("UTF-8")
87-
end
88-
ToolCall.new(id: tool_id, name: tool_name, parameters: parameters)
89-
end
9073
end
9174

9275
def should_cancel?
@@ -95,6 +78,40 @@ def should_cancel?
9578

9679
private
9780

81+
def parse_malformed_xml(input)
82+
input
83+
.scan(
84+
%r{
85+
<invoke>
86+
\s*
87+
<tool_name>
88+
([^<]+)
89+
</tool_name>
90+
\s*
91+
<parameters>
92+
(.*?)
93+
</parameters>
94+
\s*
95+
</invoke>
96+
}mx,
97+
)
98+
.map do |tool_name, params|
99+
{
100+
tool_name: tool_name.strip,
101+
parameters:
102+
params
103+
.scan(%r{
104+
<([^>]+)>
105+
(.*?)
106+
</\1>
107+
}mx)
108+
.each_with_object({}) do |(name, value), hash|
109+
hash[name.to_sym] = value.gsub(/^<!\[CDATA\[|\]\]>$/, "")
110+
end,
111+
}
112+
end
113+
end
114+
98115
def normalize_function_ids!(function_buffer)
99116
function_buffer
100117
.css("invoke")

spec/lib/completions/xml_tool_processor_spec.rb

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,26 @@
1212
expect(processor.should_cancel?).to eq(false)
1313
end
1414

15+
it "can handle mix and match xml cause tool llms may not encode" do
16+
xml = (<<~XML).strip
17+
<function_calls>
18+
<invoke>
19+
<tool_name>hello</tool_name>
20+
<parameters>
21+
<hello>world <sam>sam</sam></hello>
22+
<test><![CDATA[</h1>\n</div>\n]]></test>
23+
</parameters>
24+
</invoke>
25+
XML
26+
27+
result = []
28+
result << (processor << xml)
29+
result << (processor.finish)
30+
31+
tool_call = result.last.first
32+
expect(tool_call.parameters).to eq(hello: "world <sam>sam</sam>", test: "</h1>\n</div>\n")
33+
end
34+
1535
it "is usable for simple single message mode" do
1636
xml = (<<~XML).strip
1737
hello
@@ -149,8 +169,7 @@
149169
result << (processor.finish)
150170

151171
# Should just do its best to parse the XML
152-
tool_call =
153-
DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { param: "" })
172+
tool_call = DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: {})
154173
expect(result).to eq([["text"], [tool_call]])
155174
end
156175

0 commit comments

Comments
 (0)