Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion app/controllers/discourse_ai/ai_helper/assistant_controller.rb
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def stream_suggestion
# otherwise we may end up streaming the data to the wrong client
raise Discourse::InvalidParameters.new(:client_id) if params[:client_id].blank?

channel_id = next_channel_id
progress_channel = "discourse_ai_helper/stream_suggestions/#{channel_id}"

if location == "composer"
Jobs.enqueue(
:stream_composer_helper,
Expand All @@ -133,6 +136,7 @@ def stream_suggestion
custom_prompt: params[:custom_prompt],
force_default_locale: params[:force_default_locale] || false,
client_id: params[:client_id],
progress_channel:,
)
else
post_id = get_post_param!
Expand All @@ -148,10 +152,11 @@ def stream_suggestion
prompt: params[:mode],
custom_prompt: params[:custom_prompt],
client_id: params[:client_id],
progress_channel:,
)
end

render json: { success: true }, status: 200
render json: { success: true, progress_channel: }, status: 200
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
Expand Down Expand Up @@ -192,6 +197,18 @@ def caption_image

private

CHANNEL_ID_KEY = "discourse_ai_helper_next_channel_id"

def next_channel_id
Discourse
.redis
.pipelined do |pipeline|
pipeline.incr(CHANNEL_ID_KEY)
pipeline.expire(CHANNEL_ID_KEY, 1.day)
end
.first
end

def get_text_param!
params[:text].tap { |t| raise Discourse::InvalidParameters.new(:text) if t.blank? }
end
Expand Down
3 changes: 2 additions & 1 deletion app/jobs/regular/stream_composer_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ def execute(args)
return unless user = User.find_by(id: args[:user_id])
return unless args[:text]
return unless args[:client_id]
return unless args[:progress_channel]

helper_mode = args[:prompt]

DiscourseAi::AiHelper::Assistant.new.stream_prompt(
helper_mode,
args[:text],
user,
"/discourse-ai/ai-helper/stream_composer_suggestion",
args[:progress_channel],
force_default_locale: args[:force_default_locale],
client_id: args[:client_id],
custom_prompt: args[:custom_prompt],
Expand Down
5 changes: 4 additions & 1 deletion app/jobs/regular/stream_post_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ def execute(args)
return unless post = Post.includes(:topic).find_by(id: args[:post_id])
return unless user = User.find_by(id: args[:user_id])
return unless args[:text]
return unless args[:progress_channel]
return unless args[:client_id]

topic = post.topic
reply_to = post.reply_to_post
Expand All @@ -31,8 +33,9 @@ def execute(args)
helper_mode,
input,
user,
"/discourse-ai/ai-helper/stream_suggestion/#{post.id}",
args[:progress_channel],
custom_prompt: args[:custom_prompt],
client_id: args[:client_id],
)
end
end
Expand Down
82 changes: 44 additions & 38 deletions assets/javascripts/discourse/components/ai-post-helper-menu.gjs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { action } from "@ember/object";
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
import { service } from "@ember/service";
import { modifier } from "ember-modifier";
Expand Down Expand Up @@ -43,9 +42,6 @@ export default class AiPostHelperMenu extends Component {
@tracked lastSelectedOption = null;
@tracked isSavingFootnote = false;
@tracked supportsAddFootnote = this.args.data.supportsFastEdit;
@tracked
channel =
`/discourse-ai/ai-helper/stream_suggestion/${this.args.data.quoteState.postId}`;

@tracked
smoothStreamer = new SmoothStreamer(
Expand Down Expand Up @@ -150,19 +146,25 @@ export default class AiPostHelperMenu extends Component {
return sanitize(text);
}

@bind
set progressChannel(value) {
if (this._progressChannel) {
this.unsubscribe();
}
this._progressChannel = value;
this.subscribe();
}

subscribe() {
this.messageBus.subscribe(
this.channel,
(data) => this._updateResult(data),
this.args.data.post
.discourse_ai_helper_stream_suggestion_last_message_bus_id
);
this.messageBus.subscribe(this._progressChannel, this._updateResult);
}

@bind
unsubscribe() {
this.messageBus.unsubscribe(this.channel, this._updateResult);
if (!this._progressChannel) {
return;
}
this.messageBus.unsubscribe(this._progressChannel, this._updateResult);
this._progressChannel = null;
}

@bind
Expand All @@ -182,32 +184,37 @@ export default class AiPostHelperMenu extends Component {
this.lastSelectedOption = option;
const streamableOptions = ["explain", "translate", "custom_prompt"];

if (streamableOptions.includes(option.name)) {
return this._handleStreamedResult(option);
} else {
this._activeAiRequest = ajax("/discourse-ai/ai-helper/suggest", {
method: "POST",
data: {
mode: option.name,
text: this.args.data.quoteState.buffer,
custom_prompt: this.customPromptValue,
},
});
}
try {
if (streamableOptions.includes(option.name)) {
const streamedResult = await this._handleStreamedResult(option);
this.progressChannel = streamedResult.progress_channel;
return;
} else {
this._activeAiRequest = ajax("/discourse-ai/ai-helper/suggest", {
method: "POST",
data: {
mode: option.name,
text: this.args.data.quoteState.buffer,
custom_prompt: this.customPromptValue,
},
});
}

this._activeAiRequest
.then(({ suggestions }) => {
this.suggestion = suggestions[0].trim();

if (option.name === "proofread") {
return this._handleProofreadOption();
}
})
.catch(popupAjaxError)
.finally(() => {
this.loading = false;
this.menuState = this.MENU_STATES.result;
});
this._activeAiRequest
.then(({ suggestions }) => {
this.suggestion = suggestions[0].trim();

if (option.name === "proofread") {
return this._handleProofreadOption();
}
})
.finally(() => {
this.loading = false;
this.menuState = this.MENU_STATES.result;
});
} catch (error) {
popupAjaxError(error);
}

return this._activeAiRequest;
}
Expand Down Expand Up @@ -340,7 +347,6 @@ export default class AiPostHelperMenu extends Component {
{{else if (eq this.menuState this.MENU_STATES.result)}}
<div
class="ai-post-helper__suggestion"
{{didInsert this.subscribe}}
{{willDestroy this.unsubscribe}}
>
{{#if this.suggestion}}
Expand Down
34 changes: 17 additions & 17 deletions assets/javascripts/discourse/components/modal/diff-modal.gjs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { action } from "@ember/object";
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
import { service } from "@ember/service";
import { htmlSafe } from "@ember/template";
Expand All @@ -19,8 +18,6 @@ import DiffStreamer from "../../lib/diff-streamer";
import SmoothStreamer from "../../lib/smooth-streamer";
import AiIndicatorWave from "../ai-indicator-wave";

const CHANNEL = "/discourse-ai/ai-helper/stream_composer_suggestion";

export default class ModalDiffModal extends Component {
@service currentUser;
@service messageBus;
Expand Down Expand Up @@ -83,21 +80,26 @@ export default class ModalDiffModal extends Component {
return this.loading || this.isStreaming;
}

@bind
set progressChannel(value) {
if (this._progressChannel) {
this.messageBus.unsubscribe(this._progressChannel, this.updateResult);
}
this._progressChannel = value;
this.subscribe();
}

subscribe() {
this.messageBus.subscribe(
CHANNEL,
this.updateResult,
this.currentUser
?.discourse_ai_helper_stream_composer_suggestion_last_message_bus_id
);
// we have 1 channel per operation so we can safely subscribe at head
this.messageBus.subscribe(this._progressChannel, this.updateResult, 0);
}

@bind
cleanup() {
// stop all callbacks so it does not end up streaming pointlessly
this.#resetState();
this.messageBus.unsubscribe(CHANNEL, this.updateResult);
if (this._progressChannel) {
this.messageBus.unsubscribe(this._progressChannel, this.updateResult);
}
}

@action
Expand All @@ -122,7 +124,7 @@ export default class ModalDiffModal extends Component {

try {
this.loading = true;
return await ajax("/discourse-ai/ai-helper/stream_suggestion", {
const result = await ajax("/discourse-ai/ai-helper/stream_suggestion", {
method: "POST",
data: {
location: "composer",
Expand All @@ -133,6 +135,8 @@ export default class ModalDiffModal extends Component {
client_id: this.messageBus.clientId,
},
});

this.progressChannel = result.progress_channel;
} catch (e) {
popupAjaxError(e);
}
Expand Down Expand Up @@ -183,11 +187,7 @@ export default class ModalDiffModal extends Component {
@closeModal={{this.cleanupAndClose}}
>
<:body>
<div
{{didInsert this.subscribe}}
{{willDestroy this.cleanup}}
class="text-preview"
>
<div {{willDestroy this.cleanup}} class="text-preview">
<div
class={{concatClass
"composer-ai-helper-modal__suggestion"
Expand Down
12 changes: 0 additions & 12 deletions lib/ai_helper/entry_point.rb
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,6 @@ def inject_into(plugin)
scope.user.in_any_groups?(SiteSetting.ai_auto_image_caption_allowed_groups_map)
end,
) { object.auto_image_caption }

plugin.add_to_serializer(
:post,
:discourse_ai_helper_stream_suggestion_last_message_bus_id,
include_condition: -> { SiteSetting.ai_helper_enabled && scope.authenticated? },
) { MessageBus.last_id("/discourse-ai/ai-helper/stream_suggestion/#{object.id}") }

plugin.add_to_serializer(
:current_user,
:discourse_ai_helper_stream_composer_suggestion_last_message_bus_id,
include_condition: -> { SiteSetting.ai_helper_enabled && scope.authenticated? },
) { MessageBus.last_id("/discourse-ai/ai-helper/stream_composer_suggestion") }
end
end
end
Expand Down
24 changes: 19 additions & 5 deletions spec/jobs/regular/stream_composer_helper_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,33 @@
let(:mode) { DiscourseAi::AiHelper::Assistant::PROOFREAD }

it "does nothing if there is no user" do
channel = "/some/channel"
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion") do
job.execute(user_id: nil, text: input, prompt: mode, force_default_locale: false)
MessageBus.track_publish(channel) do
job.execute(
user_id: nil,
text: input,
prompt: mode,
force_default_locale: false,
client_id: "123",
progress_channel: channel,
)
end

expect(messages).to be_empty
end

it "does nothing if there is no text" do
channel = "/some/channel"
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion") do
MessageBus.track_publish(channel) do
job.execute(
user_id: user.id,
text: nil,
prompt: mode,
force_default_locale: false,
client_id: "123",
progress_channel: channel,
)
end

Expand All @@ -47,16 +57,18 @@

it "publishes updates with a partial result" do
proofread_result = "I like to eat pie for breakfast because it is delicious."
channel = "/channel/123"

DiscourseAi::Completions::Llm.with_prepared_responses([proofread_result]) do
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
MessageBus.track_publish(channel) do
job.execute(
user_id: user.id,
text: input,
prompt: mode,
force_default_locale: true,
client_id: "123",
progress_channel: channel,
)
end

Expand All @@ -68,16 +80,18 @@

it "publishes a final update to signal we're done" do
proofread_result = "I like to eat pie for breakfast because it is delicious."
channel = "/channel/123"

DiscourseAi::Completions::Llm.with_prepared_responses([proofread_result]) do
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
MessageBus.track_publish(channel) do
job.execute(
user_id: user.id,
text: input,
prompt: mode,
force_default_locale: true,
client_id: "123",
progress_channel: channel,
)
end

Expand Down
Loading
Loading