Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def stream_suggestion
raise Discourse::InvalidParameters.new(:custom_prompt) if params[:custom_prompt].blank?
end

# to stream we must have an appropriate client_id
# otherwise we may end up streaming the data to the wrong client
raise Discourse::InvalidParameters.new(:client_id) if params[:client_id].blank?

if location == "composer"
Jobs.enqueue(
:stream_composer_helper,
Expand All @@ -132,6 +136,7 @@ def stream_suggestion
prompt: prompt.name,
custom_prompt: params[:custom_prompt],
force_default_locale: params[:force_default_locale] || false,
client_id: params[:client_id],
)
else
post_id = get_post_param!
Expand All @@ -146,6 +151,7 @@ def stream_suggestion
text: text,
prompt: prompt.name,
custom_prompt: params[:custom_prompt],
client_id: params[:client_id],
)
end

Expand Down
2 changes: 2 additions & 0 deletions app/jobs/regular/stream_composer_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def execute(args)
return unless args[:prompt]
return unless user = User.find_by(id: args[:user_id])
return unless args[:text]
return unless args[:client_id]

prompt = CompletionPrompt.enabled_by_name(args[:prompt])

Expand All @@ -21,6 +22,7 @@ def execute(args)
user,
"/discourse-ai/ai-helper/stream_composer_suggestion",
force_default_locale: args[:force_default_locale],
client_id: args[:client_id],
)
end
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ export default class AiPostHelperMenu extends Component {
text: this.args.data.selectedText,
post_id: this.args.data.quoteState.postId,
custom_prompt: this.customPromptValue,
client_id: this.messageBus.clientId,
},
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ export default class ModalDiffModal extends Component {
text: this.selectedText,
custom_prompt: this.args.model.customPromptValue,
force_default_locale: true,
client_id: this.messageBus.clientId,
},
});
} catch (e) {
Expand Down
10 changes: 5 additions & 5 deletions assets/javascripts/discourse/lib/diff-streamer.gjs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { tracked } from "@glimmer/tracking";
import { later } from "@ember/runloop";
import { cancel, later } from "@ember/runloop";
import loadJSDiff from "discourse/lib/load-js-diff";
import { parseAsync } from "discourse/lib/text";

Expand Down Expand Up @@ -45,7 +45,7 @@ export default class DiffStreamer {
this.words = [];

if (this.typingTimer) {
clearTimeout(this.typingTimer);
cancel(this.typingTimer);
this.typingTimer = null;
}

Expand Down Expand Up @@ -100,7 +100,7 @@ export default class DiffStreamer {
this.currentCharIndex = 0;
this.isStreaming = false;
if (this.typingTimer) {
clearTimeout(this.typingTimer);
cancel(this.typingTimer);
this.typingTimer = null;
}
}
Expand Down Expand Up @@ -254,6 +254,8 @@ export default class DiffStreamer {

#formatDiffWithTags(diffArray, highlightLastWord = true) {
const wordsWithType = [];
const output = [];

diffArray.forEach((part) => {
const tokens = part.value.match(/\S+|\s+/g) || [];
tokens.forEach((token) => {
Expand All @@ -277,8 +279,6 @@ export default class DiffStreamer {
}
}

const output = [];

for (let i = 0; i <= lastWordIndex; i++) {
const { text, type } = wordsWithType[i];

Expand Down
36 changes: 30 additions & 6 deletions lib/ai_helper/assistant.rb
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,14 @@ def generate_and_send_prompt(completion_prompt, input, user, force_default_local
result
end

def stream_prompt(completion_prompt, input, user, channel, force_default_locale: false)
def stream_prompt(
completion_prompt,
input,
user,
channel,
force_default_locale: false,
client_id: nil
)
streamed_diff = +""
streamed_result = +""
start = Time.now
Expand All @@ -178,15 +185,14 @@ def stream_prompt(completion_prompt, input, user, channel, force_default_locale:
force_default_locale: force_default_locale,
) do |partial_response, cancel_function|
streamed_result << partial_response

streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?

# Throttle updates and check for safe stream points
if (streamed_result.length > 10 && (Time.now - start > 0.3)) || Rails.env.test?
sanitized = sanitize_result(streamed_result)

payload = { result: sanitized, diff: streamed_diff, done: false }
publish_update(channel, payload, user)
publish_update(channel, payload, user, client_id: client_id)
start = Time.now
end
end
Expand All @@ -195,7 +201,12 @@ def stream_prompt(completion_prompt, input, user, channel, force_default_locale:

sanitized_result = sanitize_result(streamed_result)
if sanitized_result.present?
publish_update(channel, { result: sanitized_result, diff: final_diff, done: true }, user)
publish_update(
channel,
{ result: sanitized_result, diff: final_diff, done: true },
user,
client_id: client_id,
)
end
end

Expand Down Expand Up @@ -238,8 +249,21 @@ def sanitize_result(result)
result.gsub(SANITIZE_REGEX, "")
end

def publish_update(channel, payload, user)
MessageBus.publish(channel, payload, user_ids: [user.id])
def publish_update(channel, payload, user, client_id: nil)
# when publishing we make sure we do not keep large backlogs on the channel
# and make sure we clear the streaming info after 60 seconds
# this ensures we do not bloat redis
if client_id
MessageBus.publish(
channel,
payload,
user_ids: [user.id],
client_ids: [client_id],
max_backlog_age: 60,
)
else
MessageBus.publish(channel, payload, user_ids: [user.id], max_backlog_age: 60)
end
end

def icon_map(name)
Expand Down
3 changes: 3 additions & 0 deletions spec/jobs/regular/stream_composer_helper_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
text: nil,
prompt: prompt.name,
force_default_locale: false,
client_id: "123",
)
end

Expand All @@ -58,6 +59,7 @@
text: input,
prompt: prompt.name,
force_default_locale: true,
client_id: "123",
)
end

Expand All @@ -78,6 +80,7 @@
text: input,
prompt: prompt.name,
force_default_locale: true,
client_id: "123",
)
end

Expand Down
45 changes: 36 additions & 9 deletions spec/requests/ai_helper/assistant_controller_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,41 @@

RSpec.describe DiscourseAi::AiHelper::AssistantController do
before { assign_fake_provider_to(:ai_helper_model) }
fab!(:newuser)
fab!(:user) { Fabricate(:user, refresh_auto_groups: true) }

describe "#stream_suggestion" do
before do
Jobs.run_immediately!
SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_0]
end

it "is able to stream suggestions back on appropriate channel" do
sign_in(user)
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
results = [["hello ", "world"]]
DiscourseAi::Completions::Llm.with_prepared_responses(results) do
post "/discourse-ai/ai-helper/stream_suggestion.json",
params: {
text: "hello wrld",
location: "composer",
client_id: "1234",
mode: CompletionPrompt::PROOFREAD,
}

expect(response.status).to eq(200)
end
end

last_message = messages.last
expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true)
expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true)

expect(last_message.data[:result]).to eq("hello world")
expect(last_message.data[:done]).to eq(true)
end
end

describe "#suggest" do
let(:text_to_proofread) { "The rain in spain stays mainly in the plane." }
Expand All @@ -17,10 +52,8 @@
end

context "when logged in as an user without enough privileges" do
fab!(:user) { Fabricate(:newuser) }

before do
sign_in(user)
sign_in(newuser)
SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:staff]
end

Expand All @@ -32,8 +65,6 @@
end

context "when logged in as an allowed user" do
fab!(:user)

before do
sign_in(user)
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
Expand Down Expand Up @@ -141,8 +172,6 @@
fab!(:post_2) { Fabricate(:post, topic: topic, raw: "I love bananas") }

context "when logged in as an allowed user" do
fab!(:user)

before do
sign_in(user)
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
Expand Down Expand Up @@ -219,8 +248,6 @@ def request_caption(params, caption = "A picture of a cat sitting on a table")
end

context "when logged in as an allowed user" do
fab!(:user) { Fabricate(:user, refresh_auto_groups: true) }

before do
sign_in(user)

Expand Down
Loading