Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 3d30d70

Browse files
committed
DEV: Backfill embeddings concurrently.
We are adding a new method for generating and storing embeddings in bulk, which relies on `Concurrent::Promises::Future`. Generating an embedding consists of three steps: prepare the text, make an HTTP call to retrieve the vector, and save the result to the DB. Each step is independently executed on whatever thread the pool gives us. We use a custom thread pool instead of the global executor because we want control over how many threads we spawn, to limit concurrency. We also avoid firing thousands of HTTP requests when working with large batches.
1 parent 690d6e6 commit 3d30d70

File tree

4 files changed

+110
-31
lines changed

4 files changed

+110
-31
lines changed

app/jobs/scheduled/embeddings_backfill.rb

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ def execute(args)
7575
# First, we'll try to backfill embeddings for posts that have none
7676
posts
7777
.where("#{table_name}.post_id IS NULL")
78-
.find_each do |t|
79-
vector_rep.generate_representation_from(t)
80-
rebaked += 1
78+
.find_in_batches do |batch|
79+
vector_rep.gen_bulk_reprensentations(batch)
80+
rebaked += batch.size
8181
end
8282

8383
return if rebaked >= limit
@@ -90,24 +90,28 @@ def execute(args)
9090
OR
9191
#{table_name}.strategy_version < #{strategy.version}
9292
SQL
93-
.find_each do |t|
94-
vector_rep.generate_representation_from(t)
95-
rebaked += 1
93+
.find_in_batches do |batch|
94+
vector_rep.gen_bulk_reprensentations(batch)
95+
rebaked += batch.size
9696
end
9797

9898
return if rebaked >= limit
9999

100100
# Finally, we'll try to backfill embeddings for posts that have outdated
101101
# embeddings due to edits. Here we only do 10% of the limit
102-
posts
103-
.where("#{table_name}.updated_at < ?", 7.days.ago)
104-
.order("random()")
105-
.limit((limit - rebaked) / 10)
106-
.pluck(:id)
107-
.each do |id|
108-
vector_rep.generate_representation_from(Post.find_by(id: id))
109-
rebaked += 1
110-
end
102+
posts_batch_size = 1000
103+
104+
outdated_post_ids =
105+
posts
106+
.where("#{table_name}.updated_at < ?", 7.days.ago)
107+
.order("random()")
108+
.limit((limit - rebaked) / 10)
109+
.pluck(:id)
110+
111+
outdated_post_ids.each_slice(posts_batch_size) do |batch|
112+
vector_rep.gen_bulk_reprensentations(Post.where(id: batch).order("topics.bumped_at DESC"))
113+
rebaked += batch.length
114+
end
111115

112116
rebaked
113117
end
@@ -120,14 +124,13 @@ def populate_topic_embeddings(vector_rep, topics, force: false)
120124
topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force
121125

122126
ids = topics.pluck("topics.id")
127+
batch_size = 1000
123128

124-
ids.each do |id|
125-
topic = Topic.find_by(id: id)
126-
if topic
127-
vector_rep.generate_representation_from(topic)
128-
done += 1
129-
end
129+
ids.each_slice(batch_size) do |batch|
130+
vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
131+
done += batch.length
130132
end
133+
131134
done
132135
end
133136
end

lib/embeddings/vector_representations/base.rb

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,37 @@ def vector_from(text, asymetric: false)
5050
raise NotImplementedError
5151
end
5252

53+
def gen_bulk_reprensentations(relation)
54+
raw_inputs = relation.map { |record| { target: record, text: prepare_text(record) } }
55+
56+
pool_size = 10
57+
pool =
58+
Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: pool_size, idletime: 30)
59+
60+
embedding_gen = inference_client
61+
db = RailsMultisite::ConnectionManagement.current_db
62+
promised_embeddings =
63+
raw_inputs.map do |raw|
64+
Concurrent::Promises
65+
.fulfilled_future(raw, pool)
66+
.then_on(pool) do |w_prepared_text|
67+
w_prepared_text.merge(
68+
embedding: embedding_gen.perform!(w_prepared_text[:text]),
69+
digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
70+
)
71+
end
72+
.then_on(pool) do |w_embedding|
73+
RailsMultisite::ConnectionManagement.with_connection(db) do
74+
save_to_db(w_embedding[:target], w_embedding[:embedding], w_embedding[:digest])
75+
end
76+
end
77+
end
78+
79+
Concurrent::Promises.zip(*promised_embeddings).value!
80+
end
81+
5382
def generate_representation_from(target, persist: true)
54-
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
83+
text = prepare_text(target)
5584
return if text.blank?
5685

5786
target_column =
@@ -429,6 +458,10 @@ def save_to_db(target, vector, digest)
429458
def inference_client
430459
raise NotImplementedError
431460
end
461+
462+
def prepare_text(record)
463+
@strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
464+
end
432465
end
433466
end
434467
end

lib/embeddings/vector_representations/multilingual_e5_large.rb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def vector_from(text, asymetric: false)
3434
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
3535
if needs_truncation
3636
text = tokenizer.truncate(text, max_sequence_length - 2)
37-
else
37+
elsif !text.starts_with?("query:")
3838
text = "query: #{text}"
3939
end
4040

@@ -79,6 +79,14 @@ def inference_client
7979
raise "No inference endpoint configured"
8080
end
8181
end
82+
83+
def prepare_text(record)
84+
if inference_client.class.name.include?("DiscourseClassifier")
85+
return "query: #{super(record)}"
86+
end
87+
88+
super(record)
89+
end
8290
end
8391
end
8492
end

spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# frozen_string_literal: true
22

33
RSpec.shared_examples "generates and store embedding using with vector representation" do
4-
before { @expected_embedding = [0.0038493] * vector_rep.dimensions }
4+
let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
5+
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
56

67
describe "#vector_from" do
78
it "creates a vector from a given string" do
89
text = "This is a piece of text"
9-
stub_vector_mapping(text, @expected_embedding)
10+
stub_vector_mapping(text, expected_embedding_1)
1011

11-
expect(vector_rep.vector_from(text)).to eq(@expected_embedding)
12+
expect(vector_rep.vector_from(text)).to eq(expected_embedding_1)
1213
end
1314
end
1415

@@ -24,11 +25,11 @@
2425
vector_rep.tokenizer,
2526
vector_rep.max_sequence_length - 2,
2627
)
27-
stub_vector_mapping(text, @expected_embedding)
28+
stub_vector_mapping(text, expected_embedding_1)
2829

2930
vector_rep.generate_representation_from(topic)
3031

31-
expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
32+
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
3233
end
3334

3435
it "creates a vector from a post and stores it in the database" do
@@ -38,11 +39,45 @@
3839
vector_rep.tokenizer,
3940
vector_rep.max_sequence_length - 2,
4041
)
41-
stub_vector_mapping(text, @expected_embedding)
42+
stub_vector_mapping(text, expected_embedding_1)
4243

4344
vector_rep.generate_representation_from(post)
4445

45-
expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id)
46+
expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id)
47+
end
48+
end
49+
50+
describe "#gen_bulk_reprensentations" do
51+
fab!(:topic) { Fabricate(:topic) }
52+
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
53+
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
54+
55+
fab!(:topic_2) { Fabricate(:topic) }
56+
fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
57+
fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
58+
59+
it "creates a vector for each object in the relation" do
60+
text =
61+
truncation.prepare_text_from(
62+
topic,
63+
vector_rep.tokenizer,
64+
vector_rep.max_sequence_length - 2,
65+
)
66+
67+
text2 =
68+
truncation.prepare_text_from(
69+
topic_2,
70+
vector_rep.tokenizer,
71+
vector_rep.max_sequence_length - 2,
72+
)
73+
74+
stub_vector_mapping(text, expected_embedding_1)
75+
stub_vector_mapping(text2, expected_embedding_2)
76+
77+
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
78+
79+
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
80+
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
4681
end
4782
end
4883

@@ -58,7 +93,7 @@
5893
vector_rep.tokenizer,
5994
vector_rep.max_sequence_length - 2,
6095
)
61-
stub_vector_mapping(text, @expected_embedding)
96+
stub_vector_mapping(text, expected_embedding_1)
6297
vector_rep.generate_representation_from(topic)
6398

6499
expect(

0 commit comments

Comments
 (0)