Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions app/controllers/discourse_ai/ai_helper/assistant_controller.rb
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,44 @@ def suggest_thumbnails(input)
end

def stream_suggestion
post_id = get_post_param!
text = get_text_param!
post = Post.includes(:topic).find_by(id: post_id)

location = params[:location]
raise Discourse::InvalidParameters.new(:location) if !location

prompt = CompletionPrompt.find_by(id: params[:mode])

raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
raise Discourse::InvalidParameters.new(:post_id) unless post
return suggest_thumbnails(input) if prompt.id == CompletionPrompt::ILLUSTRATE_POST

if prompt.id == CompletionPrompt::CUSTOM_PROMPT
raise Discourse::InvalidParameters.new(:custom_prompt) if params[:custom_prompt].blank?
end

Jobs.enqueue(
:stream_post_helper,
post_id: post.id,
user_id: current_user.id,
text: text,
prompt: prompt.name,
custom_prompt: params[:custom_prompt],
)
if location == "composer"
Jobs.enqueue(
:stream_composer_helper,
user_id: current_user.id,
text: text,
prompt: prompt.name,
custom_prompt: params[:custom_prompt],
force_default_locale: params[:force_default_locale] || false,
)
else
post_id = get_post_param!
post = Post.includes(:topic).find_by(id: post_id)

raise Discourse::InvalidParameters.new(:post_id) unless post

Jobs.enqueue(
:stream_post_helper,
post_id: post.id,
user_id: current_user.id,
text: text,
prompt: prompt.name,
custom_prompt: params[:custom_prompt],
)
end

render json: { success: true }, status: 200
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
Expand Down
27 changes: 27 additions & 0 deletions app/jobs/regular/stream_composer_helper.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# frozen_string_literal: true

module Jobs
class StreamComposerHelper < ::Jobs::Base
sidekiq_options retry: false

def execute(args)
return unless args[:prompt]
return unless user = User.find_by(id: args[:user_id])
return unless args[:text]

prompt = CompletionPrompt.enabled_by_name(args[:prompt])

if prompt.id == CompletionPrompt::CUSTOM_PROMPT
prompt.custom_instruction = args[:custom_prompt]
end

DiscourseAi::AiHelper::Assistant.new.stream_prompt(
prompt,
args[:text],
user,
"/discourse-ai/ai-helper/stream_composer_suggestion",
force_default_locale: args[:force_default_locale],
)
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ export default class AiPostHelperMenu extends Component {
this._activeAiRequest = ajax(fetchUrl, {
method: "POST",
data: {
location: "post",
mode: option.id,
text: this.args.data.selectedText,
post_id: this.args.data.quoteState.postId,
Expand Down
105 changes: 84 additions & 21 deletions assets/javascripts/discourse/components/modal/diff-modal.gjs
Original file line number Diff line number Diff line change
@@ -1,49 +1,94 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { action } from "@ember/object";
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
import { service } from "@ember/service";
import { htmlSafe } from "@ember/template";
import CookText from "discourse/components/cook-text";
import DButton from "discourse/components/d-button";
import DModal from "discourse/components/d-modal";
import concatClass from "discourse/helpers/concat-class";
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
import { bind } from "discourse/lib/decorators";
import { i18n } from "discourse-i18n";
import SmoothStreamer from "../../lib/smooth-streamer";
import AiIndicatorWave from "../ai-indicator-wave";

export default class ModalDiffModal extends Component {
@service currentUser;
@service messageBus;

@tracked loading = false;
@tracked diff;
@tracked suggestion = "";
@tracked
smoothStreamer = new SmoothStreamer(
() => this.suggestion,
(newValue) => (this.suggestion = newValue)
);

constructor() {
super(...arguments);
this.suggestChanges();
}

@bind
subscribe() {
const channel = "/discourse-ai/ai-helper/stream_composer_suggestion";
this.messageBus.subscribe(channel, this.updateResult);
}

@bind
unsubscribe() {
const channel = "/discourse-ai/ai-helper/stream_composer_suggestion";
this.messageBus.subscribe(channel, this.updateResult);
}

@action
async updateResult(result) {
if (result) {
this.loading = false;
}
await this.smoothStreamer.updateResult(result, "result");

if (result.done) {
this.diff = result.diff;
}

const mdTablePromptId = this.currentUser?.ai_helper_prompts.find(
(prompt) => prompt.name === "markdown_table"
).id;

// Markdown table prompt looks better with
// before/after results than diff
// despite having `type: diff`
if (this.args.model.mode === mdTablePromptId) {
this.diff = null;
}
}

@action
async suggestChanges() {
this.smoothStreamer.resetStreaming();
this.diff = null;
this.suggestion = "";
this.loading = true;

try {
const suggestion = await ajax("/discourse-ai/ai-helper/suggest", {
return await ajax("/discourse-ai/ai-helper/stream_suggestion", {
method: "POST",
data: {
location: "composer",
mode: this.args.model.mode,
text: this.args.model.selectedText,
custom_prompt: this.args.model.customPromptValue,
force_default_locale: true,
},
});

this.diff = suggestion.diff;
this.suggestion = suggestion.suggestions[0];
} catch (e) {
popupAjaxError(e);
} finally {
this.loading = false;
}
}

Expand All @@ -66,24 +111,42 @@ export default class ModalDiffModal extends Component {
@closeModal={{@closeModal}}
>
<:body>
{{#if this.loading}}
<div class="composer-ai-helper-modal__loading">
<CookText @rawText={{@model.selectedText}} />
</div>
{{else}}
{{#if this.diff}}
{{htmlSafe this.diff}}
{{else}}
<div class="composer-ai-helper-modal__old-value">
{{@model.selectedText}}
<div {{didInsert this.subscribe}} {{willDestroy this.unsubscribe}}>
{{#if this.loading}}
<div class="composer-ai-helper-modal__loading">
<CookText @rawText={{@model.selectedText}} />
</div>

<div class="composer-ai-helper-modal__new-value">
{{this.suggestion}}
{{else}}
<div
class={{concatClass
"composer-ai-helper-modal__suggestion"
"streamable-content"
(if this.smoothStreamer.isStreaming "streaming" "")
}}
>
{{#if this.smoothStreamer.isStreaming}}
<CookText
@rawText={{this.smoothStreamer.renderedText}}
class="cooked"
/>
{{else}}
{{#if this.diff}}
{{htmlSafe this.diff}}
{{else}}
<div class="composer-ai-helper-modal__old-value">
{{@model.selectedText}}
</div>
<div class="composer-ai-helper-modal__new-value">
<CookText
@rawText={{this.smoothStreamer.renderedText}}
class="cooked"
/>
</div>
{{/if}}
{{/if}}
</div>
{{/if}}
{{/if}}

</div>
</:body>

<:footer>
Expand Down
42 changes: 31 additions & 11 deletions lib/ai_helper/assistant.rb
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def custom_locale_instructions(user = nil, force_default_locale)
end
end

def localize_prompt!(prompt, user = nil, force_default_locale = false)
def localize_prompt!(prompt, user = nil, force_default_locale: false)
locale_instructions = custom_locale_instructions(user, force_default_locale)
if locale_instructions
prompt.messages[0][:content] = prompt.messages[0][:content] + locale_instructions
Expand Down Expand Up @@ -128,10 +128,10 @@ def localize_prompt!(prompt, user = nil, force_default_locale = false)
end
end

def generate_prompt(completion_prompt, input, user, force_default_locale = false, &block)
def generate_prompt(completion_prompt, input, user, force_default_locale: false, &block)
llm = helper_llm
prompt = completion_prompt.messages_with_input(input)
localize_prompt!(prompt, user, force_default_locale)
localize_prompt!(prompt, user, force_default_locale: force_default_locale)

llm.generate(
prompt,
Expand All @@ -143,8 +143,14 @@ def generate_prompt(completion_prompt, input, user, force_default_locale = false
)
end

def generate_and_send_prompt(completion_prompt, input, user, force_default_locale = false)
completion_result = generate_prompt(completion_prompt, input, user, force_default_locale)
def generate_and_send_prompt(completion_prompt, input, user, force_default_locale: false)
completion_result =
generate_prompt(
completion_prompt,
input,
user,
force_default_locale: force_default_locale,
)
result = { type: completion_prompt.prompt_type }

result[:suggestions] = (
Expand All @@ -160,24 +166,38 @@ def generate_and_send_prompt(completion_prompt, input, user, force_default_local
result
end

def stream_prompt(completion_prompt, input, user, channel)
def stream_prompt(completion_prompt, input, user, channel, force_default_locale: false)
pp "stream prompt_called #{input}, #{channel}, #{completion_prompt}, #{user}"
streamed_diff = +""
streamed_result = +""
start = Time.now

generate_prompt(completion_prompt, input, user) do |partial_response, cancel_function|
generate_prompt(
completion_prompt,
input,
user,
force_default_locale: force_default_locale,
) do |partial_response, cancel_function|
streamed_result << partial_response

# Throttle the updates
if (Time.now - start > 0.5) || Rails.env.test?
payload = { result: sanitize_result(streamed_result), done: false }
streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?

# Throttle the updates and
# checking length prevents partial tags
# that aren't sanitized correctly yet (i.e. '<output')
# from being sent in the stream
if streamed_result.length > 10 && (Time.now - start > 0.3) || Rails.env.test?
payload = { result: sanitize_result(streamed_result), diff: streamed_diff, done: false }
publish_update(channel, payload, user)
start = Time.now
end
end

final_diff = parse_diff(input, streamed_result) if completion_prompt.diff?

sanitized_result = sanitize_result(streamed_result)
if sanitized_result.present?
publish_update(channel, { result: sanitized_result, done: true }, user)
publish_update(channel, { result: sanitized_result, diff: final_diff, done: true }, user)
end
end

Expand Down
Loading
Loading