43 changes: 7 additions & 36 deletions app/controllers/discourse_ai/admin/ai_personas_controller.rb
@@ -111,55 +111,26 @@ def stream_reply

   topic_id = params[:topic_id].to_i
   topic = nil
-  post = nil

   if topic_id > 0
     topic = Topic.find(topic_id)

     raise Discourse::NotFound if topic.nil?

     if topic.topic_allowed_users.where(user_id: user.id).empty?
       return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed"))
     end
-
-    post =
-      PostCreator.create!(
-        user,
-        topic_id: topic_id,
-        raw: params[:query],
-        skip_validations: true,
-        custom_fields: {
-          DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
-        },
-      )
-  else
-    post =
-      PostCreator.create!(
-        user,
-        title: I18n.t("discourse_ai.ai_bot.default_pm_prefix"),
-        raw: params[:query],
-        archetype: Archetype.private_message,
-        target_usernames: "#{user.username},#{persona.user.username}",
-        skip_validations: true,
-        custom_fields: {
-          DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
-        },
-      )
-
-    topic = post.topic
   end

   hijack = request.env["rack.hijack"]
   io = hijack.call

-  user = current_user
-
   DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply(
-    io,
-    persona,
-    user,
-    topic,
-    post,
+    io: io,
+    persona: persona,
+    user: user,
+    topic: topic,
+    query: params[:query].to_s,
+    custom_instructions: params[:custom_instructions].to_s,
+    current_user: current_user,
   )
 end
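Note on the controller change: stream_reply no longer creates the reply post up front; it resolves the topic (when one is given) and defers post creation to the streamer's worker thread. A rough sketch of the new call shape, with illustrative query and instruction strings (io, persona, user, and current_user come from the surrounding controller code):

  # Illustrative only; passing topic: nil makes the streamer open a new PM
  # between the user and the persona's bot user.
  DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply(
    io: io,
    persona: persona,
    user: user,
    topic: nil,
    query: "how are you today?",
    custom_instructions: "Answer briefly.",
    current_user: current_user,
  )
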
5 changes: 5 additions & 0 deletions lib/ai_bot/personas/persona.rb
@@ -171,6 +171,11 @@ def craft_prompt(context, llm: nil)
       DiscourseAi::Completions::Llm.proxy(self.class.question_consolidator_llm)
     end

+    if context[:custom_instructions].present?
+      prompt_insts << "\n"
+      prompt_insts << context[:custom_instructions]
+    end
+
     fragments_guidance =
       rag_fragments_prompt(
         context[:conversation_context].to_a,
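The effect: when the context carries custom_instructions, they are appended to the persona's system prompt on their own line. Roughly, assuming prompt_insts already holds the persona's configured system prompt:

  prompt_insts = +"you are a helpful bot"   # persona's system prompt
  prompt_insts << "\n"
  prompt_insts << "To be appended to system prompt"
  prompt_insts # => "you are a helpful bot\nTo be appended to system prompt"
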
3 changes: 2 additions & 1 deletion lib/ai_bot/playground.rb
@@ -392,7 +392,7 @@ def get_context(participants:, conversation_context:, user:, skip_tool_details:
     result
   end

-  def reply_to(post, &blk)
+  def reply_to(post, custom_instructions: nil, &blk)
     # this is a multithreading issue
     # post custom prompt is needed and it may not
     # be properly loaded, ensure it is loaded
@@ -413,6 +413,7 @@ def reply_to(post, &blk)
     context[:post_id] = post.id
     context[:topic_id] = post.topic_id
     context[:private_message] = post.topic.private_message?
+    context[:custom_instructions] = custom_instructions

     reply_user = bot.bot_user
     if bot.persona.class.respond_to?(:user_id)
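Callers can now thread per-reply instructions into the prompt context; the block still receives streamed partials as before. An illustrative call (the instruction string is made up):

  DiscourseAi::AiBot::Playground
    .new(bot)
    .reply_to(post, custom_instructions: "Answer briefly.") do |partial|
      print partial # each streamed chunk of the bot's reply
    end
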
37 changes: 32 additions & 5 deletions lib/ai_bot/response_http_streamer.rb
@@ -31,9 +31,36 @@ def schedule_block(&block)

     # keeping this in a static method so we don't capture ENV and other bits
     # this allows us to release memory earlier
-    def queue_streamed_reply(io, persona, user, topic, post)
+    def queue_streamed_reply(
+      io:,
+      persona:,
+      user:,
+      topic:,
+      query:,
+      custom_instructions:,
+      current_user:
+    )
       schedule_block do
         begin
+          post_params = {
+            raw: query,
+            skip_validations: true,
+            custom_fields: {
+              DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
+            },
+          }
+
+          if topic
+            post_params[:topic_id] = topic.id
+          else
+            post_params[:title] = I18n.t("discourse_ai.ai_bot.default_pm_prefix")
+            post_params[:archetype] = Archetype.private_message
+            post_params[:target_usernames] = "#{user.username},#{persona.user.username}"
+          end
+
+          post = PostCreator.create!(user, post_params)
+          topic = post.topic
+
           io.write "HTTP/1.1 200 OK"
           io.write CRLF
           io.write "Content-Type: text/plain; charset=utf-8"
@@ -52,7 +79,7 @@ def queue_streamed_reply(io, persona, user, topic, post)
           io.flush

           persona_class =
-            DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user)
+            DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: current_user)
           bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)

           data =
@@ -69,7 +96,7 @@
           DiscourseAi::AiBot::Playground
             .new(bot)
-            .reply_to(post) do |partial|
+            .reply_to(post, custom_instructions: custom_instructions) do |partial|
               next if partial.length == 0

               data = { partial: partial }.to_json + "\n\n"
@@ -88,11 +115,11 @@
           io.write CRLF

           io.flush
-          io.done
+          io.done if io.respond_to?(:done)
         rescue StandardError => e
           # make it a tiny bit easier to debug in dev, this is tricky
           # multi-threaded code that exhibits various limitations in rails
-          p e if Rails.env.development?
+          p e if Rails.env.development? || Rails.env.test?
           Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
         ensure
           io.close
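On the wire, the hijacked socket carries a plain HTTP response whose body is a stream of newline-delimited JSON objects, one per partial. A rough sketch only; the exact headers and the final context payload follow the code above:

  HTTP/1.1 200 OK
  Content-Type: text/plain; charset=utf-8

  {"partial":"This is "}

  {"partial":"a test! Testing!"}
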
10 changes: 9 additions & 1 deletion lib/completions/endpoints/fake.rb
@@ -104,6 +104,10 @@ def self.last_call=(params)
       @last_call = params
     end

+    def self.previous_calls
+      @previous_calls ||= []
+    end
+
     def self.reset!
       @last_call = nil
       @fake_content = nil
@@ -118,7 +122,11 @@ def perform_completion!(
       feature_name: nil,
       feature_context: nil
     )
-      self.class.last_call = { dialect: dialect, user: user, model_params: model_params }
+      last_call = { dialect: dialect, user: user, model_params: model_params }
+      self.class.last_call = last_call
+      self.class.previous_calls << last_call
+      # guard memory in test
+      self.class.previous_calls.shift if self.class.previous_calls.length > 10

       content = self.class.fake_content
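previous_calls keeps a short history, capped at ten entries, so a spec can inspect earlier LLM calls rather than only the most recent. The spec below reads it along these lines:

  # fake_endpoint is the fake completions endpoint the spec already holds
  call = fake_endpoint.previous_calls[-2]   # second-to-last recorded call
  system_prompt = call[:dialect].prompt.messages.first[:content]
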
6 changes: 6 additions & 0 deletions spec/requests/admin/ai_personas_controller_spec.rb
@@ -500,6 +500,7 @@ def validate_streamed_response(raw_http, expected)
         allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
         default_llm: "custom:#{llm.id}",
         allow_personal_messages: true,
+        system_prompt: "you are a helpful bot",
       )

     io_out, io_in = IO.pipe
@@ -510,6 +511,7 @@
             query: "how are you today?",
             user_unique_id: "site:test.com:user_id:1",
             preferred_username: "test_user",
+            custom_instructions: "To be appended to system prompt",
           },
           env: {
             "rack.hijack" => lambda { io_in },
@@ -521,6 +523,10 @@
       raw = io_out.read
       context_info = validate_streamed_response(raw, "This is a test! Testing!")

+      system_prompt = fake_endpoint.previous_calls[-2][:dialect].prompt.messages.first[:content]
+
+      expect(system_prompt).to eq("you are a helpful bot\nTo be appended to system prompt")
+
       expect(context_info["topic_id"]).to be_present
       topic = Topic.find(context_info["topic_id"])
       last_post = topic.posts.order(:created_at).last