diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb index 6af4b10fc..ce390d8e3 100644 --- a/lib/completions/dialects/command.rb +++ b/lib/completions/dialects/command.rb @@ -110,7 +110,7 @@ def tool_msg(msg) end def user_msg(msg) - content = prompt.text_only(msg) + content = DiscourseAi::Completions::Prompt.text_only(msg) user_message = { role: "USER", message: content } user_message[:message] = "#{msg[:id]}: #{content}" if msg[:id] user_message diff --git a/lib/completions/dialects/nova.rb b/lib/completions/dialects/nova.rb index 9dc880973..10098c267 100644 --- a/lib/completions/dialects/nova.rb +++ b/lib/completions/dialects/nova.rb @@ -156,7 +156,7 @@ def user_msg(msg) end end - { role: "user", content: prompt.text_only(msg), images: images } + { role: "user", content: DiscourseAi::Completions::Prompt.text_only(msg), images: images } end def model_msg(msg) diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb index fe31bc1db..60d58455c 100644 --- a/lib/completions/dialects/ollama.rb +++ b/lib/completions/dialects/ollama.rb @@ -69,7 +69,7 @@ def enable_native_tool? end def user_msg(msg) - user_message = { role: "user", content: prompt.text_only(msg) } + user_message = { role: "user", content: DiscourseAi::Completions::Prompt.text_only(msg) } encoded_uploads = prompt.encoded_uploads(msg) if encoded_uploads.present? diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb index 8810257c8..0641b64f1 100644 --- a/lib/completions/prompt.rb +++ b/lib/completions/prompt.rb @@ -8,6 +8,14 @@ class Prompt attr_reader :messages, :tools, :system_message_text attr_accessor :topic_id, :post_id, :max_pixels, :tool_choice + def self.text_only(message) + if message[:content].is_a?(Array) + message[:content].map { |element| element if element.is_a?(String) }.compact.join + else + message[:content] + end + end + def initialize( system_message_text = nil, messages: [], @@ -146,14 +154,6 @@ def encoded_uploads(message) [] end - def text_only(message) - if message[:content].is_a?(Array) - message[:content].map { |element| element if element.is_a?(String) }.compact.join - else - message[:content] - end - end - def encode_upload(upload_id) UploadEncoder.encode(upload_ids: [upload_id], max_pixels: max_pixels).first end diff --git a/lib/personas/persona.rb b/lib/personas/persona.rb index d1483f0af..ff9affcc8 100644 --- a/lib/personas/persona.rb +++ b/lib/personas/persona.rb @@ -365,7 +365,7 @@ def rag_fragments_prompt(conversation_context, llm:, user:) # first response if latest_interactions.length == 1 - consolidated_question = latest_interactions[0][:content] + consolidated_question = DiscourseAi::Completions::Prompt.text_only(latest_interactions[0]) else consolidated_question = DiscourseAi::Personas::QuestionConsolidator.consolidate_question( diff --git a/lib/personas/question_consolidator.rb b/lib/personas/question_consolidator.rb index f1e0c476b..d89716a53 100644 --- a/lib/personas/question_consolidator.rb +++ b/lib/personas/question_consolidator.rb @@ -33,7 +33,7 @@ def revised_prompt row = +"" row << ((message[:type] == :user) ? "user" : "model") - content = message[:content] + content = DiscourseAi::Completions::Prompt.text_only(message) current_tokens = @llm.tokenizer.tokenize(content).length allowed_tokens = @max_tokens - tokens diff --git a/spec/lib/personas/persona_spec.rb b/spec/lib/personas/persona_spec.rb index fe310ef87..6c2266772 100644 --- a/spec/lib/personas/persona_spec.rb +++ b/spec/lib/personas/persona_spec.rb @@ -306,21 +306,24 @@ def system_prompt fab!(:llm_model) { Fabricate(:fake_model) } - it "will run the question consolidator" do + fab!(:custom_ai_persona) do + Fabricate( + :ai_persona, + name: "custom", + rag_conversation_chunks: 3, + allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], + question_consolidator_llm_id: llm_model.id, + ) + end + + before do context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) } EmbeddingsGenerationStubs.hugging_face_service(consolidated_question, context_embedding) - custom_ai_persona = - Fabricate( - :ai_persona, - name: "custom", - rag_conversation_chunks: 3, - allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], - question_consolidator_llm_id: llm_model.id, - ) - UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id]) + end + it "will run the question consolidator" do custom_persona = DiscourseAi::Personas::Persona.find_by(id: custom_ai_persona.id, user: user).new @@ -343,6 +346,36 @@ def system_prompt expect(message).to include("the time is 1") expect(message).to include("in france?") end + + context "when there are messages with uploads" do + let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") } + let(:image_upload) do + UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id) + end + + it "the question consolidator works" do + custom_persona = + DiscourseAi::Personas::Persona.find_by(id: custom_ai_persona.id, user: user).new + + context.messages = [ + { content: "Tell me the time", type: :user }, + { content: "the time is 1", type: :model }, + { content: ["in france?", { upload_id: image_upload.id }], type: :user }, + ] + + DiscourseAi::Completions::Endpoints::Fake.with_fake_content(consolidated_question) do + custom_persona.craft_prompt(context).messages.first[:content] + end + + message = + DiscourseAi::Completions::Endpoints::Fake.last_call[:dialect].prompt.messages.last[ + :content + ] + expect(message).to include("Tell me the time") + expect(message).to include("the time is 1") + expect(message).to include("in france?") + end + end end context "when a persona has RAG uploads" do