Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit c06a260

Browse files
committed
streamline cancel logic
1 parent e17e4f1 commit c06a260

File tree

7 files changed

+63
-70
lines changed

7 files changed

+63
-70
lines changed

lib/ai_bot/chat_streamer.rb

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,23 @@
66
module DiscourseAi
77
module AiBot
88
class ChatStreamer
9-
attr_accessor :cancel
109
attr_reader :reply,
1110
:guardian,
1211
:thread_id,
1312
:force_thread,
1413
:in_reply_to_id,
1514
:channel,
16-
:cancelled
17-
18-
def initialize(message:, channel:, guardian:, thread_id:, in_reply_to_id:, force_thread:)
15+
:cancel_manager
16+
17+
def initialize(
18+
message:,
19+
channel:,
20+
guardian:,
21+
thread_id:,
22+
in_reply_to_id:,
23+
force_thread:,
24+
cancel_manager: nil
25+
)
1926
@message = message
2027
@channel = channel
2128
@guardian = guardian
@@ -35,6 +42,8 @@ def initialize(message:, channel:, guardian:, thread_id:, in_reply_to_id:, force
3542
guardian: guardian,
3643
thread_id: thread_id,
3744
)
45+
46+
@cancel_manager = cancel_manager
3847
end
3948

4049
def <<(partial)
@@ -111,8 +120,7 @@ def run
111120

112121
streaming = ChatSDK::Message.stream(message_id: reply.id, raw: buffer, guardian: guardian)
113122
if !streaming
114-
cancel.call
115-
@cancelled = true
123+
@cancel_manager.cancel! if @cancel_manager
116124
end
117125
end
118126
end

lib/ai_bot/playground.rb

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def reply_to_chat_message(message, channel, context_post_ids)
331331
),
332332
user: message.user,
333333
skip_tool_details: true,
334+
cancel_manager: DiscourseAi::Completions::CancelManager.new,
334335
)
335336

336337
reply = nil
@@ -347,15 +348,14 @@ def reply_to_chat_message(message, channel, context_post_ids)
347348
thread_id: message.thread_id,
348349
in_reply_to_id: in_reply_to_id,
349350
force_thread: force_thread,
351+
cancel_manager: context.cancel_manager,
350352
)
351353

352354
new_prompts =
353-
bot.reply(context) do |partial, cancel, placeholder, type|
355+
bot.reply(context) do |partial, placeholder, type|
354356
# no support for tools or thinking by design
355357
next if type == :thinking || type == :tool_details || type == :partial_tool
356-
streamer.cancel = cancel
357358
streamer << partial
358-
break if streamer.cancelled
359359
end
360360

361361
reply = streamer.reply
@@ -383,6 +383,7 @@ def reply_to(
383383
auto_set_title: true,
384384
silent_mode: false,
385385
feature_name: nil,
386+
cancel_manager: nil,
386387
&blk
387388
)
388389
# this is a multithreading issue
@@ -472,22 +473,25 @@ def reply_to(
472473
redis_stream_key = "gpt_cancel:#{reply_post.id}"
473474
Discourse.redis.setex(redis_stream_key, MAX_STREAM_DELAY_SECONDS, 1)
474475

475-
context.cancel_manager = DiscourseAi::Completions::CancelManager.new
476+
cancel_manager ||= DiscourseAi::Completions::CancelManager.new
477+
context.cancel_manager = cancel_manager
476478
context
477479
.cancel_manager
478480
.start_monitor(delay: 0.2) do
479481
context.cancel_manager.cancel! if !Discourse.redis.get(redis_stream_key)
480482
end
483+
484+
context.cancel_manager.add_callback(
485+
lambda { reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) },
486+
)
481487
end
482488

483489
context.skip_tool_details ||= !bot.persona.class.tool_details
484-
485490
post_streamer = PostStreamer.new(delay: Rails.env.test? ? 0 : 0.5) if stream_reply
486-
487491
started_thinking = false
488492

489493
new_custom_prompts =
490-
bot.reply(context) do |partial, cancel, placeholder, type|
494+
bot.reply(context) do |partial, placeholder, type|
491495
if type == :thinking && !started_thinking
492496
reply << "<details><summary>#{I18n.t("discourse_ai.ai_bot.thinking")}</summary>"
493497
started_thinking = true
@@ -506,15 +510,6 @@ def reply_to(
506510
blk.call(partial)
507511
end
508512

509-
if stream_reply && !Discourse.redis.get(redis_stream_key)
510-
cancel&.call
511-
reply_post.update!(raw: reply, cooked: PrettyText.cook(reply))
512-
# we do not break out, cause if we do
513-
# we will not get results from bot
514-
# leading to broken context
515-
# we need to trust it to cancel at the endpoint
516-
end
517-
518513
if post_streamer
519514
post_streamer.run_later do
520515
Discourse.redis.expire(redis_stream_key, MAX_STREAM_DELAY_SECONDS)

lib/completions/endpoints/base.rb

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,9 @@ def perform_completion!(
101101

102102
wrapped = result
103103
wrapped = [result] if !result.is_a?(Array)
104-
cancelled_by_caller = false
105-
cancel_proc = -> { cancelled_by_caller = true }
106104
wrapped.each do |partial|
107-
blk.call(partial, cancel_proc)
108-
break if cancelled_by_caller
105+
blk.call(partial)
106+
# Stop replaying partials only when cancellation was actually requested;
# a bare `break <value>` would exit after the first partial every time.
break if cancel_manager&.cancelled?
109107
end
110108
return result
111109
end
@@ -176,7 +174,7 @@ def perform_completion!(
176174

177175
if @streaming_mode
178176
blk =
179-
lambda do |partial, cancel|
177+
lambda do |partial|
180178
if partial.is_a?(String)
181179
partial = xml_stripper << partial if xml_stripper
182180

@@ -185,7 +183,7 @@ def perform_completion!(
185183
partial = structured_output
186184
end
187185
end
188-
orig_blk.call(partial, cancel) if partial
186+
orig_blk.call(partial) if partial
189187
end
190188
end
191189

@@ -214,13 +212,6 @@ def perform_completion!(
214212
end
215213

216214
begin
217-
cancel = -> do
218-
cancelled = true
219-
http.finish
220-
end
221-
222-
break if cancelled
223-
224215
response.read_body do |chunk|
225216
break if cancelled
226217

@@ -233,12 +224,9 @@ def perform_completion!(
233224
partials = [partial]
234225
if xml_tool_processor && partial.is_a?(String)
235226
partials = (xml_tool_processor << partial)
236-
if xml_tool_processor.should_cancel?
237-
cancel.call
238-
break
239-
end
227+
break if xml_tool_processor.should_cancel?
240228
end
241-
partials.each { |inner_partial| blk.call(inner_partial, cancel) }
229+
partials.each { |inner_partial| blk.call(inner_partial) }
242230
end
243231
end
244232
end
@@ -248,13 +236,11 @@ def perform_completion!(
248236
response_data << stripped
249237
result = []
250238
result = (xml_tool_processor << stripped) if xml_tool_processor
251-
result.each { |partial| blk.call(partial, cancel) }
239+
result.each { |partial| blk.call(partial) }
252240
end
253241
end
254-
if xml_tool_processor
255-
xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
256-
end
257-
decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
242+
xml_tool_processor.finish.each { |partial| blk.call(partial) } if xml_tool_processor
243+
decode_chunk_finish.each { |partial| blk.call(partial) }
258244
return response_data
259245
ensure
260246
if log

lib/personas/bot.rb

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def reply(context, llm_args: {}, &update_blk)
9494
output_thinking: true,
9595
cancel_manager: context.cancel_manager,
9696
**llm_kwargs,
97-
) do |partial, cancel|
97+
) do |partial|
9898
tool =
9999
persona.find_tool(
100100
partial,
@@ -111,23 +111,22 @@ def reply(context, llm_args: {}, &update_blk)
111111
if tool_call.partial?
112112
if tool.class.allow_partial_tool_calls?
113113
tool.partial_invoke
114-
update_blk.call("", cancel, tool.custom_raw, :partial_tool)
114+
update_blk.call("", tool.custom_raw, :partial_tool)
115115
end
116116
next
117117
end
118118

119119
tool_found = true
120120
# a bit hacky, but extra newlines do no harm
121121
if needs_newlines
122-
update_blk.call("\n\n", cancel)
122+
update_blk.call("\n\n")
123123
needs_newlines = false
124124
end
125125

126126
process_tool(
127127
tool: tool,
128128
raw_context: raw_context,
129129
current_llm: current_llm,
130-
cancel: cancel,
131130
update_blk: update_blk,
132131
prompt: prompt,
133132
context: context,
@@ -146,17 +145,17 @@ def reply(context, llm_args: {}, &update_blk)
146145
else
147146
if partial.is_a?(DiscourseAi::Completions::Thinking)
148147
if partial.partial? && partial.message.present?
149-
update_blk.call(partial.message, cancel, nil, :thinking)
148+
update_blk.call(partial.message, nil, :thinking)
150149
end
151150
if !partial.partial?
152151
# this will be dealt with later
153152
raw_context << partial
154153
current_thinking << partial
155154
end
156155
elsif partial.is_a?(DiscourseAi::Completions::StructuredOutput)
157-
update_blk.call(partial, cancel, nil, :structured_output)
156+
update_blk.call(partial, nil, :structured_output)
158157
else
159-
update_blk.call(partial, cancel)
158+
update_blk.call(partial)
160159
end
161160
end
162161
end
@@ -217,14 +216,13 @@ def process_tool(
217216
tool:,
218217
raw_context:,
219218
current_llm:,
220-
cancel:,
221219
update_blk:,
222220
prompt:,
223221
context:,
224222
current_thinking:
225223
)
226224
tool_call_id = tool.tool_call_id
227-
invocation_result_json = invoke_tool(tool, cancel, context, &update_blk).to_json
225+
invocation_result_json = invoke_tool(tool, context, &update_blk).to_json
228226

229227
tool_call_message = {
230228
type: :tool_call,
@@ -258,27 +256,27 @@ def process_tool(
258256
raw_context << [invocation_result_json, tool_call_id, "tool", tool.name]
259257
end
260258

261-
def invoke_tool(tool, cancel, context, &update_blk)
259+
def invoke_tool(tool, context, &update_blk)
262260
show_placeholder = !context.skip_tool_details && !tool.class.allow_partial_tool_calls?
263261

264-
update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder
262+
update_blk.call("", build_placeholder(tool.summary, "")) if show_placeholder
265263

266264
result =
267265
tool.invoke do |progress, render_raw|
268266
if render_raw
269-
update_blk.call("", cancel, tool.custom_raw, :partial_invoke)
267+
update_blk.call("", tool.custom_raw, :partial_invoke)
270268
show_placeholder = false
271269
elsif show_placeholder
272270
placeholder = build_placeholder(tool.summary, progress)
273-
update_blk.call("", cancel, placeholder)
271+
update_blk.call("", placeholder)
274272
end
275273
end
276274

277275
if show_placeholder
278276
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
279-
update_blk.call(tool_details, cancel, nil, :tool_details)
277+
update_blk.call(tool_details, nil, :tool_details)
280278
elsif tool.custom_raw.present?
281-
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
279+
update_blk.call(tool.custom_raw, nil, :custom_raw)
282280
end
283281

284282
result

lib/personas/bot_context.rb

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def initialize(
3434
channel_id: nil,
3535
context_post_ids: nil,
3636
feature_name: "bot",
37-
resource_url: nil
37+
resource_url: nil,
38+
cancel_manager: nil
3839
)
3940
@participants = participants
4041
@user = user
@@ -55,6 +56,8 @@ def initialize(
5556
@feature_name = feature_name
5657
@resource_url = resource_url
5758

59+
@cancel_manager = cancel_manager
60+
5861
if post
5962
@post_id = post.id
6063
@topic_id = post.topic_id

spec/lib/completions/endpoints/endpoint_compliance.rb

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,11 @@ def streaming_mode_simple_prompt(mock)
188188
mock.stub_streamed_simple_call(dialect.translate) do
189189
completion_response = +""
190190

191-
endpoint.perform_completion!(dialect, user) do |partial, cancel|
191+
cancel_manager = DiscourseAi::Completions::CancelManager.new
192+
193+
endpoint.perform_completion!(dialect, user, cancel_manager: cancel_manager) do |partial|
192194
completion_response << partial
193-
cancel.call if completion_response.split(" ").length == 2
195+
cancel_manager.cancel! if completion_response.split(" ").length == 2
194196
end
195197

196198
expect(AiApiAuditLog.count).to eq(1)
@@ -212,12 +214,14 @@ def streaming_mode_tools(mock)
212214
prompt = generic_prompt(tools: [mock.tool])
213215
a_dialect = dialect(prompt: prompt)
214216

217+
cancel_manager = DiscourseAi::Completions::CancelManager.new
218+
215219
mock.stub_streamed_tool_call(a_dialect.translate) do
216220
buffered_partial = []
217221

218-
endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
222+
endpoint.perform_completion!(a_dialect, user, cancel_manager: cancel_manager) do |partial|
219223
buffered_partial << partial
220-
cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall)
224+
# Trigger cancellation as soon as the tool call arrives (replaces the old `cancel.call`).
cancel_manager.cancel! if partial.is_a?(DiscourseAi::Completions::ToolCall)
221225
end
222226

223227
expect(buffered_partial).to eq([mock.invocation_response])

spec/lib/modules/ai_bot/playground_spec.rb

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,14 +1136,13 @@
11361136

11371137
split = body.split("|")
11381138

1139+
cancel_manager = DiscourseAi::Completions::CancelManager.new
1140+
11391141
count = 0
11401142
DiscourseAi::AiBot::PostStreamer.on_callback =
11411143
proc do |callback|
11421144
count += 1
1143-
if count == 2
1144-
last_post = third_post.topic.posts.order(:id).last
1145-
Discourse.redis.del("gpt_cancel:#{last_post.id}")
1146-
end
1145+
cancel_manager.cancel! if count == 2
11471146
raise "this should not happen" if count > 2
11481147
end
11491148

@@ -1155,7 +1154,7 @@
11551154
)
11561155
# we are going to need to use real data here cause we want to trigger the
11571156
# base endpoint to cancel part way through
1158-
playground.reply_to(third_post)
1157+
playground.reply_to(third_post, cancel_manager: cancel_manager)
11591158
end
11601159

11611160
last_post = third_post.topic.posts.order(:id).last

0 commit comments

Comments
 (0)