Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit cf220c5

Browse files
authored
FIX: Improve MessageBus efficiency and correctly stop streaming (#1362)
* FIX: Improve MessageBus efficiency and correctly stop streaming This commit enhances the message bus implementation for AI helper streaming by: - Adding client_id targeting for message bus publications to ensure only the requesting client receives streaming updates - Limiting MessageBus backlog size (2) and age (60 seconds) to prevent Redis bloat - Replacing clearTimeout with Ember's cancel method for proper runloop management, we were leaking a stop - Adding tests for client-specific message delivery These changes improve memory usage and make streaming more reliable by ensuring messages are properly directed to the requesting client. * composer suggestion needed a fix as well. * backlog size of 2 is risky here cause same channel name is reused between clients
1 parent 61ef193 commit cf220c5

File tree

8 files changed

+84
-20
lines changed

8 files changed

+84
-20
lines changed

app/controllers/discourse_ai/ai_helper/assistant_controller.rb

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ def stream_suggestion
124124
raise Discourse::InvalidParameters.new(:custom_prompt) if params[:custom_prompt].blank?
125125
end
126126

127+
# to stream we must have an appropriate client_id
128+
# otherwise we may end up streaming the data to the wrong client
129+
raise Discourse::InvalidParameters.new(:client_id) if params[:client_id].blank?
130+
127131
if location == "composer"
128132
Jobs.enqueue(
129133
:stream_composer_helper,
@@ -132,6 +136,7 @@ def stream_suggestion
132136
prompt: prompt.name,
133137
custom_prompt: params[:custom_prompt],
134138
force_default_locale: params[:force_default_locale] || false,
139+
client_id: params[:client_id],
135140
)
136141
else
137142
post_id = get_post_param!
@@ -146,6 +151,7 @@ def stream_suggestion
146151
text: text,
147152
prompt: prompt.name,
148153
custom_prompt: params[:custom_prompt],
154+
client_id: params[:client_id],
149155
)
150156
end
151157

app/jobs/regular/stream_composer_helper.rb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ def execute(args)
88
return unless args[:prompt]
99
return unless user = User.find_by(id: args[:user_id])
1010
return unless args[:text]
11+
return unless args[:client_id]
1112

1213
prompt = CompletionPrompt.enabled_by_name(args[:prompt])
1314

@@ -21,6 +22,7 @@ def execute(args)
2122
user,
2223
"/discourse-ai/ai-helper/stream_composer_suggestion",
2324
force_default_locale: args[:force_default_locale],
25+
client_id: args[:client_id],
2426
)
2527
end
2628
end

assets/javascripts/discourse/components/ai-post-helper-menu.gjs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ export default class AiPostHelperMenu extends Component {
242242
text: this.args.data.selectedText,
243243
post_id: this.args.data.quoteState.postId,
244244
custom_prompt: this.customPromptValue,
245+
client_id: this.messageBus.clientId,
245246
},
246247
});
247248

assets/javascripts/discourse/components/modal/diff-modal.gjs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ export default class ModalDiffModal extends Component {
108108
text: this.selectedText,
109109
custom_prompt: this.args.model.customPromptValue,
110110
force_default_locale: true,
111+
client_id: this.messageBus.clientId,
111112
},
112113
});
113114
} catch (e) {

assets/javascripts/discourse/lib/diff-streamer.gjs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { tracked } from "@glimmer/tracking";
2-
import { later } from "@ember/runloop";
2+
import { cancel, later } from "@ember/runloop";
33
import loadJSDiff from "discourse/lib/load-js-diff";
44
import { parseAsync } from "discourse/lib/text";
55

@@ -45,7 +45,7 @@ export default class DiffStreamer {
4545
this.words = [];
4646

4747
if (this.typingTimer) {
48-
clearTimeout(this.typingTimer);
48+
cancel(this.typingTimer);
4949
this.typingTimer = null;
5050
}
5151

@@ -100,7 +100,7 @@ export default class DiffStreamer {
100100
this.currentCharIndex = 0;
101101
this.isStreaming = false;
102102
if (this.typingTimer) {
103-
clearTimeout(this.typingTimer);
103+
cancel(this.typingTimer);
104104
this.typingTimer = null;
105105
}
106106
}
@@ -254,6 +254,8 @@ export default class DiffStreamer {
254254

255255
#formatDiffWithTags(diffArray, highlightLastWord = true) {
256256
const wordsWithType = [];
257+
const output = [];
258+
257259
diffArray.forEach((part) => {
258260
const tokens = part.value.match(/\S+|\s+/g) || [];
259261
tokens.forEach((token) => {
@@ -277,8 +279,6 @@ export default class DiffStreamer {
277279
}
278280
}
279281

280-
const output = [];
281-
282282
for (let i = 0; i <= lastWordIndex; i++) {
283283
const { text, type } = wordsWithType[i];
284284

lib/ai_helper/assistant.rb

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,14 @@ def generate_and_send_prompt(completion_prompt, input, user, force_default_local
166166
result
167167
end
168168

169-
def stream_prompt(completion_prompt, input, user, channel, force_default_locale: false)
169+
def stream_prompt(
170+
completion_prompt,
171+
input,
172+
user,
173+
channel,
174+
force_default_locale: false,
175+
client_id: nil
176+
)
170177
streamed_diff = +""
171178
streamed_result = +""
172179
start = Time.now
@@ -178,15 +185,14 @@ def stream_prompt(completion_prompt, input, user, channel, force_default_locale:
178185
force_default_locale: force_default_locale,
179186
) do |partial_response, cancel_function|
180187
streamed_result << partial_response
181-
182188
streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?
183189

184190
# Throttle updates and check for safe stream points
185191
if (streamed_result.length > 10 && (Time.now - start > 0.3)) || Rails.env.test?
186192
sanitized = sanitize_result(streamed_result)
187193

188194
payload = { result: sanitized, diff: streamed_diff, done: false }
189-
publish_update(channel, payload, user)
195+
publish_update(channel, payload, user, client_id: client_id)
190196
start = Time.now
191197
end
192198
end
@@ -195,7 +201,12 @@ def stream_prompt(completion_prompt, input, user, channel, force_default_locale:
195201

196202
sanitized_result = sanitize_result(streamed_result)
197203
if sanitized_result.present?
198-
publish_update(channel, { result: sanitized_result, diff: final_diff, done: true }, user)
204+
publish_update(
205+
channel,
206+
{ result: sanitized_result, diff: final_diff, done: true },
207+
user,
208+
client_id: client_id,
209+
)
199210
end
200211
end
201212

@@ -238,8 +249,21 @@ def sanitize_result(result)
238249
result.gsub(SANITIZE_REGEX, "")
239250
end
240251

241-
def publish_update(channel, payload, user)
242-
MessageBus.publish(channel, payload, user_ids: [user.id])
252+
def publish_update(channel, payload, user, client_id: nil)
253+
# when publishing we make sure we do not keep large backlogs on the channel
254+
# and make sure we clear the streaming info after 60 seconds
255+
# this ensures we do not bloat redis
256+
if client_id
257+
MessageBus.publish(
258+
channel,
259+
payload,
260+
user_ids: [user.id],
261+
client_ids: [client_id],
262+
max_backlog_age: 60,
263+
)
264+
else
265+
MessageBus.publish(channel, payload, user_ids: [user.id], max_backlog_age: 60)
266+
end
243267
end
244268

245269
def icon_map(name)

spec/jobs/regular/stream_composer_helper_spec.rb

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
text: nil,
3636
prompt: prompt.name,
3737
force_default_locale: false,
38+
client_id: "123",
3839
)
3940
end
4041

@@ -58,6 +59,7 @@
5859
text: input,
5960
prompt: prompt.name,
6061
force_default_locale: true,
62+
client_id: "123",
6163
)
6264
end
6365

@@ -78,6 +80,7 @@
7880
text: input,
7981
prompt: prompt.name,
8082
force_default_locale: true,
83+
client_id: "123",
8184
)
8285
end
8386

spec/requests/ai_helper/assistant_controller_spec.rb

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,41 @@
22

33
RSpec.describe DiscourseAi::AiHelper::AssistantController do
44
before { assign_fake_provider_to(:ai_helper_model) }
5+
fab!(:newuser)
6+
fab!(:user) { Fabricate(:user, refresh_auto_groups: true) }
7+
8+
describe "#stream_suggestion" do
9+
before do
10+
Jobs.run_immediately!
11+
SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_0]
12+
end
13+
14+
it "is able to stream suggestions back on appropriate channel" do
15+
sign_in(user)
16+
messages =
17+
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
18+
results = [["hello ", "world"]]
19+
DiscourseAi::Completions::Llm.with_prepared_responses(results) do
20+
post "/discourse-ai/ai-helper/stream_suggestion.json",
21+
params: {
22+
text: "hello wrld",
23+
location: "composer",
24+
client_id: "1234",
25+
mode: CompletionPrompt::PROOFREAD,
26+
}
27+
28+
expect(response.status).to eq(200)
29+
end
30+
end
31+
32+
last_message = messages.last
33+
expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true)
34+
expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true)
35+
36+
expect(last_message.data[:result]).to eq("hello world")
37+
expect(last_message.data[:done]).to eq(true)
38+
end
39+
end
540

641
describe "#suggest" do
742
let(:text_to_proofread) { "The rain in spain stays mainly in the plane." }
@@ -17,10 +52,8 @@
1752
end
1853

1954
context "when logged in as an user without enough privileges" do
20-
fab!(:user) { Fabricate(:newuser) }
21-
2255
before do
23-
sign_in(user)
56+
sign_in(newuser)
2457
SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:staff]
2558
end
2659

@@ -32,8 +65,6 @@
3265
end
3366

3467
context "when logged in as an allowed user" do
35-
fab!(:user)
36-
3768
before do
3869
sign_in(user)
3970
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
@@ -141,8 +172,6 @@
141172
fab!(:post_2) { Fabricate(:post, topic: topic, raw: "I love bananas") }
142173

143174
context "when logged in as an allowed user" do
144-
fab!(:user)
145-
146175
before do
147176
sign_in(user)
148177
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
@@ -219,8 +248,6 @@ def request_caption(params, caption = "A picture of a cat sitting on a table")
219248
end
220249

221250
context "when logged in as an allowed user" do
222-
fab!(:user) { Fabricate(:user, refresh_auto_groups: true) }
223-
224251
before do
225252
sign_in(user)
226253

0 commit comments

Comments
 (0)