Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 1e2936d

Browse files
committed
REFACTOR: A Simpler way of interacting with embeddings' tables.
This change adds a new abstraction called `Schema`, which acts as a repository that supports the same DB features `VectorRepresentation::Base` has, with the exception that removes the need to have duplicated methods per embeddings table. It is also a bit more flexible when performing a similarity search because you can pass it a block that gives you access to the builder, allowing you to add multiple joins/where conditions.
1 parent 34f43f3 commit 1e2936d

24 files changed

+321
-75
lines changed

app/jobs/regular/digest_rag_upload.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ def execute(args)
1818
target = target_type.constantize.find_by(id: target_id)
1919
return if !target
2020

21-
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
22-
vector_rep =
23-
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
21+
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
2422

2523
tokenizer = vector_rep.tokenizer
2624
chunk_tokens = target.rag_chunk_tokens

app/jobs/regular/generate_embeddings.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ def execute(args)
1616
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
1717
return if post.raw.blank?
1818

19-
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
20-
vector_rep =
21-
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
19+
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
2220

2321
vector_rep.generate_representation_from(target)
2422
end

app/jobs/regular/generate_rag_embeddings.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ class GenerateRagEmbeddings < ::Jobs::Base
88
def execute(args)
99
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?
1010

11-
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
12-
vector_rep =
13-
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
11+
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
1412

1513
# generate_representation_from checks compares the digest value to make sure
1614
# the embedding is only generated once per fragment unless something changes.

app/jobs/scheduled/embeddings_backfill.rb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@ def execute(args)
2020

2121
rebaked = 0
2222

23-
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
24-
vector_rep =
25-
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
23+
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
2624
table_name = vector_rep.topic_table_name
2725

2826
topics =
@@ -41,7 +39,7 @@ def execute(args)
4139
relation = topics.where(<<~SQL).limit(limit - rebaked)
4240
#{table_name}.model_version < #{vector_rep.version}
4341
OR
44-
#{table_name}.strategy_version < #{strategy.version}
42+
#{table_name}.strategy_version < #{vector_rep.strategy_version}
4543
SQL
4644

4745
rebaked += populate_topic_embeddings(vector_rep, relation)

app/models/rag_document_fragment.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ def update_target_uploads(target, upload_ids)
3939
end
4040

4141
def indexing_status(persona, uploads)
42-
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
43-
vector_rep =
44-
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
42+
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
4543

4644
embeddings_table = vector_rep.rag_fragments_table_name
4745

db/migrate/20240611170905_move_embeddings_to_single_table_per_type.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,7 @@ def up
147147
SQL
148148

149149
begin
150-
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
151-
vector_rep =
152-
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
150+
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
153151
rescue StandardError => e
154152
Rails.logger.error("Failed to index embeddings: #{e}")
155153
end

lib/ai_bot/personas/persona.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,9 +314,7 @@ def rag_fragments_prompt(conversation_context, llm:, user:)
314314

315315
return nil if !consolidated_question
316316

317-
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
318-
vector_rep =
319-
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
317+
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
320318
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
321319

322320
interactions_vector = vector_rep.vector_from(consolidated_question)

lib/ai_bot/tool_runner.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,7 @@ def rag_search(query, filenames: nil, limit: 10)
141141

142142
return [] if upload_refs.empty?
143143

144-
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
145-
vector_rep =
146-
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
144+
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
147145
query_vector = vector_rep.vector_from(query)
148146
fragment_ids =
149147
vector_rep.asymmetric_rag_fragment_similarity_search(

lib/ai_helper/semantic_categorizer.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,7 @@ def tags
9292
private
9393

9494
def nearest_neighbors(limit: 100)
95-
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
96-
vector_rep =
97-
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
95+
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
9896

9997
raw_vector = vector_rep.vector_from(@text)
10098

lib/embeddings/schema.rb

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# frozen_string_literal: true
2+
3+
# We don't have AR objects for our embeddings, so this class
4+
# acts as an intermediary between us and the DB.
5+
# It lets us retrieve embeddings either symmetrically and asymmetrically,
6+
# and also store them.
7+
8+
module DiscourseAi
9+
module Embeddings
10+
class Schema
11+
TOPICS_TABLE = "ai_topic_embeddings"
12+
POSTS_TABLE = "ai_post_embeddings"
13+
RAG_DOCS_TABLE = "ai_document_fragment_embeddings"
14+
15+
def self.for(
16+
target_klass,
17+
vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
18+
)
19+
case target_klass&.name
20+
when "Topic"
21+
new(TOPICS_TABLE, "topic_id", vector)
22+
when "Post"
23+
new(POSTS_TABLE, "post_id", vector)
24+
when "RagDocumentFragment"
25+
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector)
26+
else
27+
raise ArgumentError, "Invalid target type for embeddings"
28+
end
29+
end
30+
31+
def initialize(table, target_column, vector)
32+
@table = table
33+
@target_column = target_column
34+
@vector = vector
35+
end
36+
37+
attr_reader :table, :target_column, :vector
38+
39+
def find_by_embedding(embedding)
40+
DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first
41+
SELECT *
42+
FROM #{table}
43+
WHERE
44+
model_id = :vid AND strategy_id = :vsid
45+
ORDER BY
46+
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
47+
LIMIT 1
48+
SQL
49+
end
50+
51+
def find_by_target(target)
52+
DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first
53+
SELECT *
54+
FROM #{table}
55+
WHERE
56+
model_id = :vid AND
57+
strategy_id = :vsid AND
58+
#{target_column} = :target_id
59+
LIMIT 1
60+
SQL
61+
end
62+
63+
def asymmetric_similarity_search(embedding, limit:, offset:)
64+
builder = DB.build(<<~SQL)
65+
WITH candidates AS (
66+
SELECT
67+
#{target_column},
68+
embeddings::halfvec(#{dimensions}) AS embeddings
69+
FROM
70+
#{table}
71+
/*join*/
72+
/*where*/
73+
ORDER BY
74+
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
75+
LIMIT :limit * 2
76+
)
77+
SELECT
78+
#{target_column},
79+
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
80+
FROM
81+
candidates
82+
ORDER BY
83+
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
84+
LIMIT :limit
85+
OFFSET :offset
86+
SQL
87+
88+
builder.where(
89+
"model_id = :model_id AND strategy_id = :strategy_id",
90+
model_id: vector.id,
91+
strategy_id: vector.strategy_id,
92+
)
93+
94+
yield(builder) if block_given?
95+
96+
builder.query(query_embedding: embedding, limit: limit, offset: offset)
97+
rescue PG::Error => e
98+
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
99+
raise MissingEmbeddingError
100+
end
101+
102+
def symmetric_similarity_search(record)
103+
builder = DB.build(<<~SQL)
104+
WITH le_target AS (
105+
SELECT
106+
embeddings
107+
FROM
108+
#{table}
109+
WHERE
110+
model_id = :vid AND
111+
strategy_id = :vsid AND
112+
#{target_column} = :target_id
113+
LIMIT 1
114+
)
115+
SELECT #{target_column} FROM (
116+
SELECT
117+
#{target_column}, embeddings
118+
FROM
119+
#{table}
120+
/*join*/
121+
/*where*/
122+
ORDER BY
123+
binary_quantize(embeddings)::bit(#{dimensions}) <~> (
124+
SELECT
125+
binary_quantize(embeddings)::bit(#{dimensions})
126+
FROM
127+
le_target
128+
LIMIT 1
129+
)
130+
LIMIT 200
131+
) AS widenet
132+
ORDER BY
133+
embeddings::halfvec(#{dimensions}) #{pg_function} (
134+
SELECT
135+
embeddings::halfvec(#{dimensions})
136+
FROM
137+
le_target
138+
LIMIT 1
139+
)
140+
LIMIT 100;
141+
SQL
142+
143+
builder.where("model_id = :vid AND strategy_id = :vsid")
144+
145+
yield(builder) if block_given?
146+
147+
builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id)
148+
rescue PG::Error => e
149+
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
150+
raise MissingEmbeddingError
151+
end
152+
153+
def store(record, embedding, digest)
154+
DB.exec(
155+
<<~SQL,
156+
INSERT INTO #{table} (#{target_column}, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
157+
VALUES (:target_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
158+
ON CONFLICT (model_id, strategy_id, post_id)
159+
DO UPDATE SET
160+
model_version = :model_version,
161+
strategy_version = :strategy_version,
162+
digest = :digest,
163+
embeddings = '[:embeddings]',
164+
updated_at = :now
165+
SQL
166+
target_id: record.id,
167+
model_id: vector.id,
168+
model_version: vector.version,
169+
strategy_id: vector.strategy_id,
170+
strategy_version: vector.strategy_version,
171+
digest: digest,
172+
embeddings: embedding,
173+
now: Time.zone.now,
174+
)
175+
end
176+
177+
private
178+
179+
delegate :dimensions, :pg_function, to: :vector
180+
end
181+
end
182+
end

0 commit comments

Comments
 (0)