Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 90ec441

Browse files
committed
DEV: refactor bot internals
This introduces a proper object for bot context, this makes it simpler to improve context management as we go cause we have a nice object to work with Starts refactoring allowing for a single message to have multiple uploads throughout
1 parent 1dde82e commit 90ec441

File tree

9 files changed

+142
-54
lines changed

9 files changed

+142
-54
lines changed

lib/ai_bot/bot.rb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ def get_updated_title(conversation_context, post, user)
7575
def force_tool_if_needed(prompt, context)
7676
return if prompt.tool_choice == :none
7777

78-
context[:chosen_tools] ||= []
78+
context.chosen_tools ||= []
7979
forced_tools = persona.force_tool_use.map { |tool| tool.name }
80-
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
80+
force_tool = forced_tools.find { |name| !context.chosen_tools.include?(name) }
8181

8282
if force_tool && persona.forced_tool_count > 0
8383
user_turns = prompt.messages.select { |m| m[:type] == :user }.length
8484
force_tool = false if user_turns > persona.forced_tool_count
8585
end
8686

8787
if force_tool
88-
context[:chosen_tools] << force_tool
88+
context.chosen_tools << force_tool
8989
prompt.tool_choice = force_tool
9090
else
9191
prompt.tool_choice = nil
@@ -100,7 +100,7 @@ def reply(context, &update_blk)
100100
ongoing_chain = true
101101
raw_context = []
102102

103-
user = context[:user]
103+
user = context.user
104104

105105
llm_kwargs = { user: user }
106106
llm_kwargs[:temperature] = persona.temperature if persona.temperature
@@ -297,7 +297,7 @@ def process_tool(
297297
end
298298

299299
def invoke_tool(tool, llm, cancel, context, &update_blk)
300-
show_placeholder = !context[:skip_tool_details] && !tool.class.allow_partial_tool_calls?
300+
show_placeholder = !context.skip_tool_details && !tool.class.allow_partial_tool_calls?
301301

302302
update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder
303303

lib/ai_bot/bot_context.rb

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# frozen_string_literal: true
2+
3+
module DiscourseAi
4+
module AiBot
5+
class BotContext
6+
attr_accessor :messages,
7+
:topic_id,
8+
:post_id,
9+
:private_message,
10+
:custom_instructions,
11+
:user,
12+
:skip_tool_details,
13+
:participants,
14+
:chosen_tools
15+
16+
def initialize(
17+
post: nil,
18+
participants: nil,
19+
user: nil,
20+
skip_tool_details: nil,
21+
messages:,
22+
custom_instructions: nil
23+
)
24+
@participants = participants
25+
@user = user
26+
@skip_tool_details = skip_tool_details
27+
@messages = messages
28+
@custom_instructions = custom_instructions
29+
30+
if post
31+
@post_id = post.id
32+
@topic_id = post.topic_id
33+
@private_message = post.topic.private_message?
34+
@participants = post.topic.allowed_users.map(&:username).join(", ") if @private_message
35+
@user = post.user
36+
end
37+
end
38+
39+
# these are strings that can be safely interpolated into templates
40+
TEMPLATE_PARAMS = %w[time site_url site_title site_description]
41+
42+
def lookup_template_param(key)
43+
public_send(key.to_sym) if TEMPLATE_PARAMS.include?(key)
44+
end
45+
46+
def time
47+
@time ||= Time.zone.now
48+
end
49+
50+
def site_url
51+
Discourse.base_url
52+
end
53+
54+
def site_title
55+
SiteSetting.title
56+
end
57+
58+
def site_description
59+
SiteSetting.site_description
60+
end
61+
62+
def private_message?
63+
@private_message
64+
end
65+
end
66+
end
67+
end

lib/ai_bot/personas/persona.rb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def available_tools
163163
def craft_prompt(context, llm: nil)
164164
system_insts =
165165
system_prompt.gsub(/\{(\w+)\}/) do |match|
166-
found = context[match[1..-2].to_sym]
166+
found = context.lookup_template_param(match[1..-2])
167167
found.nil? ? match : found.to_s
168168
end
169169

@@ -180,26 +180,26 @@ def craft_prompt(context, llm: nil)
180180
)
181181
end
182182

183-
if context[:custom_instructions].present?
183+
if context.custom_instructions.present?
184184
prompt_insts << "\n"
185-
prompt_insts << context[:custom_instructions]
185+
prompt_insts << context.custom_instructions
186186
end
187187

188188
fragments_guidance =
189189
rag_fragments_prompt(
190-
context[:conversation_context].to_a,
190+
context.messages,
191191
llm: question_consolidator_llm,
192-
user: context[:user],
192+
user: context.user,
193193
)&.strip
194194

195195
prompt_insts << fragments_guidance if fragments_guidance.present?
196196

197197
prompt =
198198
DiscourseAi::Completions::Prompt.new(
199199
prompt_insts,
200-
messages: context[:conversation_context].to_a,
201-
topic_id: context[:topic_id],
202-
post_id: context[:post_id],
200+
messages: context.messages,
201+
topic_id: context.topic_id,
202+
post_id: context.post_id,
203203
)
204204

205205
prompt.max_pixels = self.class.vision_max_pixels if self.class.vision_enabled

lib/ai_bot/playground.rb

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def update_playground_with(post)
227227
schedule_bot_reply(post) if can_attach?(post)
228228
end
229229

230-
def conversation_context(post, style: nil)
230+
def post_prompt_messages(post, style: nil)
231231
# Pay attention to the `post_number <= ?` here.
232232
# We want to inject the last post as context because they are translated differently.
233233

@@ -307,10 +307,10 @@ def conversation_context(post, style: nil)
307307
end
308308

309309
def title_playground(post, user)
310-
context = conversation_context(post)
310+
messages = post_prompt_messages(post)
311311

312312
bot
313-
.get_updated_title(context, post, user)
313+
.get_updated_title(messages, post, user)
314314
.tap do |new_title|
315315
PostRevisor.new(post.topic.first_post, post.topic).revise!(
316316
bot.bot_user,
@@ -326,7 +326,7 @@ def title_playground(post, user)
326326
)
327327
end
328328

329-
def chat_context(message, channel, persona_user, context_post_ids)
329+
def chat_prompt_messages(message, channel, persona_user, context_post_ids)
330330
has_vision = bot.persona.class.vision_enabled
331331
include_thread_titles = !channel.direct_message_channel? && !message.thread_id
332332

@@ -411,9 +411,9 @@ def reply_to_chat_message(message, channel, context_post_ids)
411411
context_post_ids = nil if !channel.direct_message_channel?
412412

413413
context =
414-
get_context(
415-
participants: participants.join(", "),
416-
conversation_context: chat_context(message, channel, persona_user, context_post_ids),
414+
BotContext.new(
415+
participants: participants,
416+
messages: chat_prompt_messages(message, channel, persona_user, context_post_ids),
417417
user: message.user,
418418
skip_tool_details: true,
419419
)
@@ -460,22 +460,6 @@ def reply_to_chat_message(message, channel, context_post_ids)
460460
reply
461461
end
462462

463-
def get_context(participants:, conversation_context:, user:, skip_tool_details: nil)
464-
result = {
465-
site_url: Discourse.base_url,
466-
site_title: SiteSetting.title,
467-
site_description: SiteSetting.site_description,
468-
time: Time.zone.now,
469-
participants: participants,
470-
conversation_context: conversation_context,
471-
user: user,
472-
}
473-
474-
result[:skip_tool_details] = true if skip_tool_details
475-
476-
result
477-
end
478-
479463
def reply_to(
480464
post,
481465
custom_instructions: nil,
@@ -510,15 +494,11 @@ def reply_to(
510494
)
511495

512496
context =
513-
get_context(
514-
participants: post.topic.allowed_users.map(&:username).join(", "),
515-
conversation_context: conversation_context(post, style: context_style),
516-
user: post.user,
497+
BotContext.new(
498+
post: post,
499+
custom_instructions: custom_instructions,
500+
messages: post_prompt_messages(post, style: context_style),
517501
)
518-
context[:post_id] = post.id
519-
context[:topic_id] = post.topic_id
520-
context[:private_message] = post.topic.private_message?
521-
context[:custom_instructions] = custom_instructions
522502

523503
reply_user = bot.bot_user
524504
if bot.persona.class.respond_to?(:user_id)
@@ -562,7 +542,7 @@ def reply_to(
562542
Discourse.redis.setex(redis_stream_key, 60, 1)
563543
end
564544

565-
context[:skip_tool_details] ||= !bot.persona.class.tool_details
545+
context.skip_tool_details ||= !bot.persona.class.tool_details
566546

567547
post_streamer = PostStreamer.new(delay: Rails.env.test? ? 0 : 0.5) if stream_reply
568548

lib/ai_bot/tool_runner.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def attach_upload(mini_racer_context)
457457
UploadCreator.new(
458458
file,
459459
filename,
460-
for_private_message: @context[:private_message],
460+
for_private_message: @context.private_message,
461461
).create_for(@bot_user.id)
462462

463463
{ id: upload.id, short_url: upload.short_url, url: upload.url }

lib/ai_bot/tools/dall_e.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def invoke
111111
UploadCreator.new(
112112
file,
113113
"image.png",
114-
for_private_message: context[:private_message],
114+
for_private_message: context.private_message?,
115115
).create_for(bot_user.id),
116116
}
117117
end

lib/completions/prompt.rb

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,28 @@ def has_tools?
115115
tools.present?
116116
end
117117

118-
# helper method to get base64 encoded uploads
119-
# at the correct dimentions
118+
# TODO migrate to new API
120119
def encoded_uploads(message)
121120
return [] if message[:upload_ids].blank?
122121
UploadEncoder.encode(upload_ids: message[:upload_ids], max_pixels: max_pixels)
123122
end
124123

124+
def encode_upload(upload_id)
125+
UploadEncoder.encode(upload_ids: [upload_id], max_pixels: max_pixels).first
126+
end
127+
128+
def content_with_encoded_uploads(content)
129+
return [content] unless content.is_a?(Array)
130+
131+
content.map do |c|
132+
if c.is_a?(Hash) && c.key?(:upload_id)
133+
encode_upload(c[:upload_id])
134+
else
135+
c
136+
end
137+
end
138+
end
139+
125140
def ==(other)
126141
return false unless other.is_a?(Prompt)
127142
messages == other.messages && tools == other.tools && topic_id == other.topic_id &&
@@ -166,8 +181,17 @@ def validate_message(message)
166181
if message[:upload_ids].present? && message[:type] != :user
167182
raise ArgumentError, "upload_ids are only supported for users"
168183
end
169-
170-
raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String)
184+
if message[:content].is_a?(Array)
185+
message[:content].each do |content|
186+
if !content.is_a?(String) && !(content.is_a?(Hash) && content.keys == [:upload_id])
187+
raise ArgumentError, "Array message content must be a string or {upload_id: ...} "
188+
end
189+
end
190+
else
191+
if !message[:content].is_a?(String)
192+
raise ArgumentError, "Message content must be a string or an array"
193+
end
194+
end
171195
end
172196

173197
def validate_turn(last_turn, new_turn)

spec/lib/completions/prompt_spec.rb

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,23 @@
2525
end
2626

2727
describe "image support" do
28+
it "allows adding uploads inline in messages" do
29+
upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
30+
31+
prompt.max_pixels = 300
32+
prompt.push(
33+
type: :user,
34+
content: ["this is an image", { upload_id: upload.id }, "this was an image"],
35+
)
36+
37+
encoded = prompt.content_with_encoded_uploads(prompt.messages.last[:content])
38+
39+
expect(encoded.length).to eq(3)
40+
expect(encoded[0]).to eq("this is an image")
41+
expect(encoded[1][:mime_type]).to eq("image/jpeg")
42+
expect(encoded[2]).to eq("this was an image")
43+
end
44+
2845
it "allows adding uploads to messages" do
2946
upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
3047

spec/lib/modules/ai_bot/playground_spec.rb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,7 @@
11551155
end
11561156
end
11571157

1158-
describe "#conversation_context" do
1158+
describe "#post_prompt_messages" do
11591159
context "with limited context" do
11601160
before do
11611161
@old_persona = playground.bot.persona
@@ -1166,7 +1166,7 @@
11661166
after { playground.bot.persona = @old_persona }
11671167

11681168
it "respects max_context_post" do
1169-
context = playground.conversation_context(third_post)
1169+
context = playground.post_prompt_messages(third_post)
11701170

11711171
expect(context).to contain_exactly(
11721172
*[{ type: :user, id: user.username, content: third_post.raw }],
@@ -1215,7 +1215,7 @@
12151215

12161216
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
12171217

1218-
context = playground.conversation_context(third_post)
1218+
context = playground.post_prompt_messages(third_post)
12191219

12201220
expect(context).to contain_exactly(
12211221
*[

0 commit comments

Comments
 (0)