Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 03eccbe

Browse files
authored
FEATURE: Make tool support polymorphic (#798)
Polymorphic RAG means that we will be able to access RAG fragments from both AiPersona and AiCustomTool. In turn, this gives us support for richer RAG implementations.
1 parent b16390a commit 03eccbe

File tree

15 files changed

+132
-61
lines changed

15 files changed

+132
-61
lines changed

app/controllers/discourse_ai/admin/ai_personas_controller.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def show
4141
def create
4242
ai_persona = AiPersona.new(ai_persona_params.except(:rag_uploads))
4343
if ai_persona.save
44-
RagDocumentFragment.link_persona_and_uploads(ai_persona, attached_upload_ids)
44+
RagDocumentFragment.link_target_and_uploads(ai_persona, attached_upload_ids)
4545

4646
render json: {
4747
ai_persona: LocalizedAiPersonaSerializer.new(ai_persona, root: false),
@@ -59,7 +59,7 @@ def create_user
5959

6060
def update
6161
if @ai_persona.update(ai_persona_params.except(:rag_uploads))
62-
RagDocumentFragment.update_persona_uploads(@ai_persona, attached_upload_ids)
62+
RagDocumentFragment.update_target_uploads(@ai_persona, attached_upload_ids)
6363

6464
render json: LocalizedAiPersonaSerializer.new(@ai_persona, root: false)
6565
else

app/jobs/regular/digest_rag_upload.rb

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,24 @@ class DigestRagUpload < ::Jobs::Base
99
# TODO(roman): Add a way to automatically recover from errors, resulting in unindexed uploads.
1010
def execute(args)
1111
return if (upload = Upload.find_by(id: args[:upload_id])).nil?
12-
return if (ai_persona = AiPersona.find_by(id: args[:ai_persona_id])).nil?
12+
13+
target_type = args[:target_type]
14+
target_id = args[:target_id]
15+
16+
return if !target_type || !target_id
17+
18+
target = target_type.constantize.find_by(id: target_id)
19+
return if !target
1320

1421
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
1522
vector_rep =
1623
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
1724

1825
tokenizer = vector_rep.tokenizer
19-
chunk_tokens = ai_persona.rag_chunk_tokens
20-
overlap_tokens = ai_persona.rag_chunk_overlap_tokens
26+
chunk_tokens = target.rag_chunk_tokens
27+
overlap_tokens = target.rag_chunk_overlap_tokens
2128

22-
fragment_ids = RagDocumentFragment.where(ai_persona: ai_persona, upload: upload).pluck(:id)
29+
fragment_ids = RagDocumentFragment.where(target: target, upload: upload).pluck(:id)
2330

2431
# Check if this is the first time we process this upload.
2532
if fragment_ids.empty?
@@ -39,7 +46,7 @@ def execute(args)
3946
overlap_tokens: overlap_tokens,
4047
) do |chunk, metadata|
4148
fragment_ids << RagDocumentFragment.create!(
42-
ai_persona: ai_persona,
49+
target: target,
4350
fragment: chunk,
4451
fragment_number: idx + 1,
4552
upload: upload,

app/jobs/regular/generate_rag_embeddings.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ def execute(args)
1717
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
1818

1919
last_fragment = fragments.last
20-
ai_persona = last_fragment.ai_persona
20+
target = last_fragment.target
2121
upload = last_fragment.upload
2222

23-
indexing_status = RagDocumentFragment.indexing_status(ai_persona, [upload])[upload.id]
23+
indexing_status = RagDocumentFragment.indexing_status(target, [upload])[upload.id]
2424
RagDocumentFragment.publish_status(upload, indexing_status)
2525
end
2626
end

app/models/ai_persona.rb

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,15 @@ class AiPersona < ActiveRecord::Base
2020
validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 }
2121
validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 }
2222
validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 }
23+
has_many :rag_document_fragments, dependent: :destroy, as: :target
2324

2425
belongs_to :created_by, class_name: "User"
2526
belongs_to :user
2627

2728
has_many :upload_references, as: :target, dependent: :destroy
2829
has_many :uploads, through: :upload_references
2930

30-
has_many :rag_document_fragment, dependent: :destroy
31-
32-
has_many :rag_document_fragments, through: :ai_persona_rag_document_fragments
33-
3431
before_destroy :ensure_not_system
35-
3632
before_update :regenerate_rag_fragments
3733

3834
def self.persona_cache
@@ -230,7 +226,7 @@ def create_user!
230226

231227
def regenerate_rag_fragments
232228
if rag_chunk_tokens_changed? || rag_chunk_overlap_tokens_changed?
233-
RagDocumentFragment.where(ai_persona: self).delete_all
229+
RagDocumentFragment.where(target: self).delete_all
234230
end
235231
end
236232

app/models/rag_document_fragment.rb

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,40 @@
11
# frozen_string_literal: true
22

33
class RagDocumentFragment < ActiveRecord::Base
4+
# TODO Jan 2025 - remove
5+
self.ignored_columns = %i[ai_persona_id]
6+
47
belongs_to :upload
5-
belongs_to :ai_persona
8+
belongs_to :target, polymorphic: true
69

710
class << self
8-
def link_persona_and_uploads(persona, upload_ids)
9-
return if persona.blank?
11+
def link_target_and_uploads(target, upload_ids)
12+
return if target.blank?
1013
return if upload_ids.blank?
1114
return if !SiteSetting.ai_embeddings_enabled?
1215

13-
UploadReference.ensure_exist!(upload_ids: upload_ids, target: persona)
16+
UploadReference.ensure_exist!(upload_ids: upload_ids, target: target)
1417

1518
upload_ids.each do |upload_id|
16-
Jobs.enqueue(:digest_rag_upload, ai_persona_id: persona.id, upload_id: upload_id)
19+
Jobs.enqueue(
20+
:digest_rag_upload,
21+
target_id: target.id,
22+
target_type: target.class.to_s,
23+
upload_id: upload_id,
24+
)
1725
end
1826
end
1927

20-
def update_persona_uploads(persona, upload_ids)
21-
return if persona.blank?
28+
def update_target_uploads(target, upload_ids)
29+
return if target.blank?
2230
return if !SiteSetting.ai_embeddings_enabled?
2331

2432
if upload_ids.blank?
25-
RagDocumentFragment.where(ai_persona: persona).destroy_all
26-
UploadReference.where(target: persona).destroy_all
33+
RagDocumentFragment.where(target: target).destroy_all
34+
UploadReference.where(target: target).destroy_all
2735
else
28-
RagDocumentFragment.where(ai_persona: persona).where.not(upload_id: upload_ids).destroy_all
29-
link_persona_and_uploads(persona, upload_ids)
36+
RagDocumentFragment.where(target: target).where.not(upload_id: upload_ids).destroy_all
37+
link_target_and_uploads(target, upload_ids)
3038
end
3139
end
3240

@@ -37,18 +45,25 @@ def indexing_status(persona, uploads)
3745

3846
embeddings_table = vector_rep.rag_fragments_table_name
3947

40-
results = DB.query(<<~SQL, persona_id: persona.id, upload_ids: uploads.map(&:id))
48+
results =
49+
DB.query(
50+
<<~SQL,
4151
SELECT
4252
uploads.id,
4353
SUM(CASE WHEN (rdf.upload_id IS NOT NULL) THEN 1 ELSE 0 END) AS total,
4454
SUM(CASE WHEN (eft.rag_document_fragment_id IS NOT NULL) THEN 1 ELSE 0 END) as indexed,
4555
SUM(CASE WHEN (rdf.upload_id IS NOT NULL AND eft.rag_document_fragment_id IS NULL) THEN 1 ELSE 0 END) as left
4656
FROM uploads
47-
LEFT OUTER JOIN rag_document_fragments rdf ON uploads.id = rdf.upload_id AND rdf.ai_persona_id = :persona_id
57+
LEFT OUTER JOIN rag_document_fragments rdf ON uploads.id = rdf.upload_id AND rdf.target_id = :target_id
58+
AND rdf.target_type = :target_type
4859
LEFT OUTER JOIN #{embeddings_table} eft ON rdf.id = eft.rag_document_fragment_id
4960
WHERE uploads.id IN (:upload_ids)
5061
GROUP BY uploads.id
5162
SQL
63+
target_id: persona.id,
64+
target_type: persona.class.to_s,
65+
upload_ids: uploads.map(&:id),
66+
)
5267

5368
results.reduce({}) do |acc, r|
5469
acc[r.id] = { total: r.total, indexed: r.indexed, left: r.left }
@@ -78,4 +93,10 @@ def publish_status(upload, status)
7893
# created_at :datetime not null
7994
# updated_at :datetime not null
8095
# metadata :text
96+
# target_id :integer
97+
# target_type :string(800)
98+
#
99+
# Indexes
100+
#
101+
# index_rag_document_fragments_on_target_type_and_target_id (target_type,target_id)
81102
#
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# frozen_string_literal: true
2+
3+
class AddTargetToRagDocumentFragment < ActiveRecord::Migration[7.1]
4+
def change
5+
add_column :rag_document_fragments, :target_id, :integer, null: true
6+
add_column :rag_document_fragments, :target_type, :string, limit: 800, null: true
7+
add_index :rag_document_fragments, %i[target_type target_id]
8+
end
9+
end
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# frozen_string_literal: true
2+
class DropPersonaIdFromRagDocumentFragments < ActiveRecord::Migration[7.1]
3+
def change
4+
execute <<~SQL
5+
UPDATE rag_document_fragments
6+
SET
7+
target_type = 'AiPersona',
8+
target_id = ai_persona_id
9+
WHERE ai_persona_id IS NOT NULL
10+
SQL
11+
12+
# unlikely, but let's be safe
13+
execute <<~SQL
14+
DELETE FROM rag_document_fragments
15+
WHERE target_id IS NULL OR target_type IS NULL
16+
SQL
17+
18+
remove_column :rag_document_fragments, :ai_persona_id
19+
change_column_null :rag_document_fragments, :target_id, false
20+
change_column_null :rag_document_fragments, :target_type, false
21+
end
22+
end

lib/ai_bot/personas/persona.rb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ def rag_fragments_prompt(conversation_context, llm:, user:)
288288
candidate_fragment_ids =
289289
vector_rep.asymmetric_rag_fragment_similarity_search(
290290
interactions_vector,
291-
persona_id: id,
291+
target_type: "AiPersona",
292+
target_id: id,
292293
limit:
293294
(
294295
if reranker.reranker_configured?

lib/embeddings/vector_representations/base.rb

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,8 @@ def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_dista
280280

281281
def asymmetric_rag_fragment_similarity_search(
282282
raw_vector,
283-
persona_id:,
283+
target_id:,
284+
target_type:,
284285
limit:,
285286
offset:,
286287
return_distance: false
@@ -299,14 +300,16 @@ def asymmetric_rag_fragment_similarity_search(
299300
WHERE
300301
model_id = #{id} AND
301302
strategy_id = #{@strategy.id} AND
302-
rdf.ai_persona_id = :persona_id
303+
rdf.target_id = :target_id AND
304+
rdf.target_type = :target_type
303305
ORDER BY
304306
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
305307
LIMIT :limit
306308
OFFSET :offset
307309
SQL
308310
query_embedding: raw_vector,
309-
persona_id: persona_id,
311+
target_id: target_id,
312+
target_type: target_type,
310313
limit: limit,
311314
offset: offset,
312315
)

spec/jobs/regular/digest_rag_upload_spec.rb

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@
4141
# be explicit here about chunking strategy
4242
persona.update!(rag_chunk_tokens: 100, rag_chunk_overlap_tokens: 10)
4343

44-
described_class.new.execute(upload_id: upload_with_metadata.id, ai_persona_id: persona.id)
44+
described_class.new.execute(
45+
upload_id: upload_with_metadata.id,
46+
target_id: persona.id,
47+
target_type: persona.class.to_s,
48+
)
4549

4650
parsed = +""
4751
first = true
@@ -66,7 +70,11 @@
6670
before { File.expects(:open).returns(document_file) }
6771

6872
it "splits an upload into chunks" do
69-
subject.execute(upload_id: upload.id, ai_persona_id: persona.id)
73+
subject.execute(
74+
upload_id: upload.id,
75+
target_id: persona.id,
76+
target_type: persona.class.to_s,
77+
)
7078

7179
created_fragment = RagDocumentFragment.last
7280

@@ -76,19 +84,23 @@
7684
end
7785

7886
it "queue jobs to generate embeddings for each fragment" do
79-
expect { subject.execute(upload_id: upload.id, ai_persona_id: persona.id) }.to change(
80-
Jobs::GenerateRagEmbeddings.jobs,
81-
:size,
82-
).by(1)
87+
expect {
88+
subject.execute(
89+
upload_id: upload.id,
90+
target_id: persona.id,
91+
target_type: persona.class.to_s,
92+
)
93+
}.to change(Jobs::GenerateRagEmbeddings.jobs, :size).by(1)
8394
end
8495
end
8596

8697
it "doesn't generate new fragments if we already processed the upload" do
87-
Fabricate(:rag_document_fragment, upload: upload, ai_persona: persona)
88-
previous_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
98+
Fabricate(:rag_document_fragment, upload: upload, target: persona)
8999

90-
subject.execute(upload_id: upload.id, ai_persona_id: persona.id)
91-
updated_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
100+
previous_count = RagDocumentFragment.where(upload: upload, target: persona).count
101+
102+
subject.execute(upload_id: upload.id, target_id: persona.id, target_type: persona.class.to_s)
103+
updated_count = RagDocumentFragment.where(upload: upload, target: persona).count
92104

93105
expect(updated_count).to eq(previous_count)
94106
end

0 commit comments

Comments
 (0)