Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 534b0df

Browse files
authored
REFACTOR: Separation of concerns for embedding generation. (#1027)
In a previous refactor, we moved the responsibility of querying and storing embeddings into the `Schema` class. Now, it's time for embedding generation. The motivation behind these changes is to isolate vector characteristics in simple objects to later replace them with a DB-backed version, similar to what we did with LLM configs.
1 parent 222e2cf commit 534b0df

36 files changed

+375
-496
lines changed

app/jobs/regular/generate_embeddings.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ def execute(args)
1616
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
1717
return if post.raw.blank?
1818

19-
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
20-
21-
vector_rep.generate_representation_from(target)
19+
DiscourseAi::Embeddings::Vector.instance.generate_representation_from(target)
2220
end
2321
end
2422
end

app/jobs/regular/generate_rag_embeddings.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ class GenerateRagEmbeddings < ::Jobs::Base
88
def execute(args)
99
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?
1010

11-
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
11+
vector = DiscourseAi::Embeddings::Vector.instance
1212

1313
# generate_representation_from compares the digest value to make sure
1414
# the embedding is only generated once per fragment unless something changes.
15-
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
15+
fragments.map { |fragment| vector.generate_representation_from(fragment) }
1616

1717
last_fragment = fragments.last
1818
target = last_fragment.target

app/jobs/scheduled/embeddings_backfill.rb

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def execute(args)
2020

2121
rebaked = 0
2222

23-
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
23+
vector = DiscourseAi::Embeddings::Vector.instance
24+
vector_def = vector.vdef
2425
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
2526

2627
topics =
@@ -30,19 +31,19 @@ def execute(args)
3031
.where(deleted_at: nil)
3132
.order("topics.bumped_at DESC")
3233

33-
rebaked += populate_topic_embeddings(vector_rep, topics.limit(limit - rebaked))
34+
rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))
3435

3536
return if rebaked >= limit
3637

3738
# Then, we'll try to backfill embeddings for topics that have outdated
3839
# embeddings, be it model or strategy version
3940
relation = topics.where(<<~SQL).limit(limit - rebaked)
40-
#{table_name}.model_version < #{vector_rep.version}
41+
#{table_name}.model_version < #{vector_def.version}
4142
OR
42-
#{table_name}.strategy_version < #{vector_rep.strategy_version}
43+
#{table_name}.strategy_version < #{vector_def.strategy_version}
4344
SQL
4445

45-
rebaked += populate_topic_embeddings(vector_rep, relation)
46+
rebaked += populate_topic_embeddings(vector, relation)
4647

4748
return if rebaked >= limit
4849

@@ -54,7 +55,7 @@ def execute(args)
5455
.where("#{table_name}.updated_at < topics.updated_at")
5556
.limit((limit - rebaked) / 10)
5657

57-
populate_topic_embeddings(vector_rep, relation, force: true)
58+
populate_topic_embeddings(vector, relation, force: true)
5859

5960
return if rebaked >= limit
6061

@@ -76,7 +77,7 @@ def execute(args)
7677
.limit(limit - rebaked)
7778
.pluck(:id)
7879
.each_slice(posts_batch_size) do |batch|
79-
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
80+
vector.gen_bulk_reprensentations(Post.where(id: batch))
8081
rebaked += batch.length
8182
end
8283

@@ -86,14 +87,14 @@ def execute(args)
8687
# embeddings, be it model or strategy version
8788
posts
8889
.where(<<~SQL)
89-
#{table_name}.model_version < #{vector_rep.version}
90+
#{table_name}.model_version < #{vector_def.version}
9091
OR
91-
#{table_name}.strategy_version < #{vector_rep.strategy_version}
92+
#{table_name}.strategy_version < #{vector_def.strategy_version}
9293
SQL
9394
.limit(limit - rebaked)
9495
.pluck(:id)
9596
.each_slice(posts_batch_size) do |batch|
96-
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
97+
vector.gen_bulk_reprensentations(Post.where(id: batch))
9798
rebaked += batch.length
9899
end
99100

@@ -107,7 +108,7 @@ def execute(args)
107108
.limit((limit - rebaked) / 10)
108109
.pluck(:id)
109110
.each_slice(posts_batch_size) do |batch|
110-
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
111+
vector.gen_bulk_reprensentations(Post.where(id: batch))
111112
rebaked += batch.length
112113
end
113114

@@ -116,7 +117,7 @@ def execute(args)
116117

117118
private
118119

119-
def populate_topic_embeddings(vector_rep, topics, force: false)
120+
def populate_topic_embeddings(vector, topics, force: false)
120121
done = 0
121122

122123
topics =
@@ -126,7 +127,7 @@ def populate_topic_embeddings(vector_rep, topics, force: false)
126127
batch_size = 1000
127128

128129
ids.each_slice(batch_size) do |batch|
129-
vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
130+
vector.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
130131
done += batch.length
131132
end
132133

lib/ai_bot/personas/persona.rb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,10 +314,10 @@ def rag_fragments_prompt(conversation_context, llm:, user:)
314314

315315
return nil if !consolidated_question
316316

317-
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
317+
vector = DiscourseAi::Embeddings::Vector.instance
318318
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
319319

320-
interactions_vector = vector_rep.vector_from(consolidated_question)
320+
interactions_vector = vector.vector_from(consolidated_question)
321321

322322
rag_conversation_chunks = self.class.rag_conversation_chunks
323323
search_limit =
@@ -327,7 +327,7 @@ def rag_fragments_prompt(conversation_context, llm:, user:)
327327
rag_conversation_chunks
328328
end
329329

330-
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
330+
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)
331331

332332
candidate_fragment_ids =
333333
schema

lib/ai_bot/tool_runner.rb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,10 @@ def rag_search(query, filenames: nil, limit: 10)
141141

142142
return [] if upload_refs.empty?
143143

144-
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
145-
query_vector = vector_rep.vector_from(query)
144+
query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query)
146145
fragment_ids =
147146
DiscourseAi::Embeddings::Schema
148-
.for(RagDocumentFragment, vector: vector_rep)
147+
.for(RagDocumentFragment)
149148
.asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
150149
builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
151150
rag_document_fragments ON

lib/ai_helper/semantic_categorizer.rb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ def tags
9292
private
9393

9494
def nearest_neighbors(limit: 100)
95-
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
96-
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
95+
vector = DiscourseAi::Embeddings::Vector.instance
96+
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
9797

98-
raw_vector = vector_rep.vector_from(@text)
98+
raw_vector = vector.vector_from(@text)
9999

100100
muted_category_ids = nil
101101
if @user.present?

lib/embeddings/schema.rb

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,31 @@ class Schema
1414

1515
def self.for(
1616
target_klass,
17-
vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
17+
vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
1818
)
1919
case target_klass&.name
2020
when "Topic"
21-
new(TOPICS_TABLE, "topic_id", vector)
21+
new(TOPICS_TABLE, "topic_id", vector_def)
2222
when "Post"
23-
new(POSTS_TABLE, "post_id", vector)
23+
new(POSTS_TABLE, "post_id", vector_def)
2424
when "RagDocumentFragment"
25-
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector)
25+
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
2626
else
2727
raise ArgumentError, "Invalid target type for embeddings"
2828
end
2929
end
3030

31-
def initialize(table, target_column, vector)
31+
def initialize(table, target_column, vector_def)
3232
@table = table
3333
@target_column = target_column
34-
@vector = vector
34+
@vector_def = vector_def
3535
end
3636

37-
attr_reader :table, :target_column, :vector
37+
attr_reader :table, :target_column, :vector_def
3838

3939
def find_by_embedding(embedding)
40-
DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first
40+
DB.query(
41+
<<~SQL,
4142
SELECT *
4243
FROM #{table}
4344
WHERE
@@ -46,10 +47,15 @@ def find_by_embedding(embedding)
4647
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
4748
LIMIT 1
4849
SQL
50+
query_embedding: embedding,
51+
vid: vector_def.id,
52+
vsid: vector_def.strategy_id,
53+
).first
4954
end
5055

5156
def find_by_target(target)
52-
DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first
57+
DB.query(
58+
<<~SQL,
5359
SELECT *
5460
FROM #{table}
5561
WHERE
@@ -58,6 +64,10 @@ def find_by_target(target)
5864
#{target_column} = :target_id
5965
LIMIT 1
6066
SQL
67+
target_id: target.id,
68+
vid: vector_def.id,
69+
vsid: vector_def.strategy_id,
70+
).first
6171
end
6272

6373
def asymmetric_similarity_search(embedding, limit:, offset:)
@@ -87,8 +97,8 @@ def asymmetric_similarity_search(embedding, limit:, offset:)
8797

8898
builder.where(
8999
"model_id = :model_id AND strategy_id = :strategy_id",
90-
model_id: vector.id,
91-
strategy_id: vector.strategy_id,
100+
model_id: vector_def.id,
101+
strategy_id: vector_def.strategy_id,
92102
)
93103

94104
yield(builder) if block_given?
@@ -156,7 +166,7 @@ def symmetric_similarity_search(record)
156166

157167
yield(builder) if block_given?
158168

159-
builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id)
169+
builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
160170
rescue PG::Error => e
161171
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
162172
raise MissingEmbeddingError
@@ -176,10 +186,10 @@ def store(record, embedding, digest)
176186
updated_at = :now
177187
SQL
178188
target_id: record.id,
179-
model_id: vector.id,
180-
model_version: vector.version,
181-
strategy_id: vector.strategy_id,
182-
strategy_version: vector.strategy_version,
189+
model_id: vector_def.id,
190+
model_version: vector_def.version,
191+
strategy_id: vector_def.strategy_id,
192+
strategy_version: vector_def.strategy_version,
183193
digest: digest,
184194
embeddings: embedding,
185195
now: Time.zone.now,
@@ -188,7 +198,7 @@ def store(record, embedding, digest)
188198

189199
private
190200

191-
delegate :dimensions, :pg_function, to: :vector
201+
delegate :dimensions, :pg_function, to: :vector_def
192202
end
193203
end
194204
end

lib/embeddings/semantic_related.rb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,13 @@ def self.clear_cache_for(topic)
1313
def related_topic_ids_for(topic)
1414
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
1515

16-
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
1716
cache_for = results_ttl(topic)
1817

1918
Discourse
2019
.cache
2120
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
2221
DiscourseAi::Embeddings::Schema
23-
.for(Topic, vector: vector_rep)
22+
.for(Topic)
2423
.symmetric_similarity_search(topic)
2524
.map(&:topic_id)
2625
.tap do |candidate_ids|

lib/embeddings/semantic_search.rb

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def cached_query?(query)
3030
Discourse.cache.read(embedding_key).present?
3131
end
3232

33-
def vector_rep
34-
@vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
33+
def vector
34+
@vector ||= DiscourseAi::Embeddings::Vector.instance
3535
end
3636

3737
def hyde_embedding(search_term)
@@ -52,16 +52,14 @@ def hyde_embedding(search_term)
5252

5353
Discourse
5454
.cache
55-
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
55+
.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(hypothetical_post) }
5656
end
5757

5858
def embedding(search_term)
5959
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
6060
embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)
6161

62-
Discourse
63-
.cache
64-
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) }
62+
Discourse.cache.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(search_term) }
6563
end
6664

6765
# this ensures the candidate topics are over selected
@@ -84,7 +82,7 @@ def search_for_topics(query, page = 1, hyde: true)
8482

8583
over_selection_limit = limit * OVER_SELECTION_FACTOR
8684

87-
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
85+
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
8886

8987
candidate_topic_ids =
9088
schema.asymmetric_similarity_search(
@@ -114,7 +112,7 @@ def quick_search(query)
114112

115113
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
116114

117-
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
115+
vector = DiscourseAi::Embeddings::Vector.instance
118116

119117
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
120118

@@ -129,12 +127,12 @@ def quick_search(query)
129127
Discourse
130128
.cache
131129
.fetch(embedding_key, expires_in: 1.week) do
132-
vector_rep.vector_from(search_term, asymetric: true)
130+
vector.vector_from(search_term, asymetric: true)
133131
end
134132

135133
candidate_post_ids =
136134
DiscourseAi::Embeddings::Schema
137-
.for(Post, vector: vector_rep)
135+
.for(Post, vector_def: vector.vdef)
138136
.asymmetric_similarity_search(
139137
search_term_embedding,
140138
limit: max_semantic_results_per_page,

lib/embeddings/strategies/truncation.rb

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,28 @@ def version
1212
1
1313
end
1414

15-
def prepare_text_from(target, tokenizer, max_length)
15+
def prepare_target_text(target, vdef)
16+
max_length = vdef.max_sequence_length - 2
17+
1618
case target
1719
when Topic
18-
topic_truncation(target, tokenizer, max_length)
20+
topic_truncation(target, vdef.tokenizer, max_length)
1921
when Post
20-
post_truncation(target, tokenizer, max_length)
22+
post_truncation(target, vdef.tokenizer, max_length)
2123
when RagDocumentFragment
22-
tokenizer.truncate(target.fragment, max_length)
24+
vdef.tokenizer.truncate(target.fragment, max_length)
2325
else
2426
raise ArgumentError, "Invalid target type"
2527
end
2628
end
2729

30+
# Builds the text that will be turned into an embedding for a search query.
#
# When +asymetric+ is true, the vector definition's asymmetric query prefix
# is prepended so models trained for asymmetric retrieval can distinguish
# queries from documents.
#
# @param text [String] the raw query text
# @param vdef [Object] vector definition exposing #tokenizer,
#   #max_sequence_length and #asymmetric_query_prefix
# @param asymetric [Boolean] prepend the asymmetric query prefix when true
#   (keyword name kept as-is — it is part of the public interface used by callers)
# @return [String] the (possibly prefixed) query text, truncated to fit the model
def prepare_query_text(text, vdef, asymetric: false)
  qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text
  # - 2 presumably leaves room for special tokens (mirrors prepare_target_text)
  max_length = vdef.max_sequence_length - 2

  # FIX: truncate the prefixed query (qtext), not the original text —
  # previously qtext was computed but never used, so the asymmetric
  # prefix was silently dropped.
  vdef.tokenizer.truncate(qtext, max_length)
end
36+
2837
private
2938

3039
def topic_information(topic)

0 commit comments

Comments
 (0)